tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference has been flagged as potentially problematic.

Files changed (250)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +78 -1
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +1 -43
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +14 -9
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +38 -7
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +17 -0
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +28 -5
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +74 -35
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +88 -25
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -64
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +72 -37
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +45 -15
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +14 -0
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +41 -16
  232. tpu_inference/spec_decode/__init__.py +13 -0
  233. tpu_inference/spec_decode/jax/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  235. tpu_inference/tpu_info.py +14 -0
  236. tpu_inference/utils.py +42 -36
  237. tpu_inference/worker/__init__.py +13 -0
  238. tpu_inference/worker/tpu_worker.py +63 -50
  239. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
  240. tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
  241. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  242. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  245. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  246. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  247. tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
  248. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
  249. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
  250. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
tests/models/jax/test_llama4.py (new file)
@@ -0,0 +1,298 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from dataclasses import field
+ from types import SimpleNamespace
+ from typing import Any, Tuple
+ from unittest.mock import MagicMock, patch
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ import pytest
+ from flax import nnx
+ from flax.typing import PRNGKey
+ from jax.sharding import Mesh
+ from vllm.config import ModelConfig
+
+ from tpu_inference.models.jax.llama4 import (Llama4ForCausalLM,
+                                              Llama4WeightLoader)
+
+
+ class MockParamLlama4:
+     """A mock for a parameter used in the Llama4 model."""
+     shape: Tuple[int, ...]
+     dtype: jnp.dtype = jnp.bfloat16
+     sharding_spec: Tuple[str | None, ...] | None = None
+     value: Any = field(init=False)
+     sharding: Any = field(init=False)
+
+     def __init__(self, shape=(32, 128)):
+         self.shape = shape
+         self.value = jnp.zeros(self.shape, dtype=self.dtype)
+         # The sharding spec is accessed during weight loading
+         self.sharding = SimpleNamespace(spec=self.sharding_spec)
+
+     # Allow the mock parameter's value to be updated
+     def __setattr__(self, name, value):
+         if name in ['value', 'shape', 'dtype', 'sharding', 'sharding_spec']:
+             self.__dict__[name] = value
+         else:
+             super().__setattr__(name, value)
+
+
+ class MockVllmConfig:
+     """A mock VllmConfig sufficient for testing the Llama4 model."""
+
+     def __init__(self,
+                  model_name: str,
+                  random_weights: bool = False,
+                  tensor_parallelism: int = 1):
+         self.model_config = MagicMock(spec=ModelConfig)
+         self.load_config = MagicMock()
+         self.load_config.download_dir = None
+
+         # Keep the mocked model dimensions small to avoid OOM.
+         self.model_config.get_vocab_size.return_value = 202048
+         self.model_config.get_hidden_size.return_value = 32
+         self.model_config.model = model_name
+
+         self.additional_config = {
+             "random_weights": random_weights,
+             "sharding": {
+                 "sharding_strategy": {
+                     "tensor_parallelism": tensor_parallelism
+                 }
+             }
+         }
+
+         self.cache_config = MagicMock(cache_dtype="auto")
+
+         text_config_mock = MagicMock()
+         text_config_mock.interleave_moe_layer_step = 1
+         text_config_mock.num_attention_heads = 40
+         text_config_mock.num_key_value_heads = 8
+         text_config_mock.head_dim = 128
+
+         hf_config_mock = MagicMock()
+         hf_config_mock.text_config = text_config_mock
+
+         self.model_config.hf_config = hf_config_mock
+
+
+ @pytest.fixture(scope="module")
+ def mesh():
+     """
+     Creates a mesh with all required axes for testing.
+     """
+     if not jax.devices():
+         pytest.skip("No JAX devices available for mesh creation.")
+
+     devices = np.array(jax.local_devices())
+     # Reshape devices into a 4D array to name 4 axes: data, attn_dp, model,
+     # and expert. All axes except 'data' will have a size of 1.
+     num_devices = len(devices)
+     device_mesh = devices.reshape((num_devices, 1, 1, 1))
+
+     with Mesh(device_mesh,
+               axis_names=('data', 'attn_dp', 'model', 'expert')) as m:
+         yield m
+
+
+ @pytest.fixture
+ def rng() -> PRNGKey:
+     """Provides a reusable JAX PRNGKey."""
+     return jax.random.PRNGKey(42)
+
+
+ @pytest.fixture
+ def mock_vllm_config_llama4() -> MockVllmConfig:
+     return MockVllmConfig(model_name="meta-llama/Llama-4-Scout-17B-16E")
+
+
+ class TestLlama4ForCausalLM:
+     """Tests for the main Llama4ForCausalLM model class."""
+
+     def test_init_llama4(self, mock_vllm_config_llama4, rng, mesh):
+         """Tests correct parameter detection for the Llama4 model variant."""
+         model = Llama4ForCausalLM(mock_vllm_config_llama4, rng, mesh)
+         assert model.hidden_size == 32
+         assert "llama-4" in model.vllm_config.model_config.model.lower()
+
+     def test_create_model_with_random_weights(self, mock_vllm_config_llama4,
+                                               rng, mesh):
+         """
+         Tests that random weight initialization creates concrete, non-zero-variance arrays.
+         """
+         with jax.set_mesh(mesh):
+             model = Llama4ForCausalLM(vllm_config=mock_vllm_config_llama4,
+                                       rng=rng,
+                                       mesh=mesh,
+                                       force_random_weights=True)
+             embedding_weight = model.embedder.input_embedding_table_VD.value
+             attention_q_kernel = model.layers[0].attn.kernel_q_proj_DNH.value
+             final_norm_scale = model.final_norm.scale.value
+
+             assert isinstance(embedding_weight, jax.Array)
+             assert isinstance(attention_q_kernel, jax.Array)
+             assert isinstance(final_norm_scale, jax.Array)
+
+             assert jnp.std(embedding_weight) > 0
+             assert jnp.std(attention_q_kernel) > 0
+
+             assert jnp.all(final_norm_scale == 1.0)
+
+     @patch("tpu_inference.models.jax.llama4.Llama4WeightLoader")
+     def test_load_weights_called_correctly(self, mock_loader_cls, rng, mesh):
+         """Tests that the weight loader is called correctly for checkpoint loading."""
+         vllm_config = MockVllmConfig(model_name="llama4-scout",
+                                      random_weights=False)
+         model = Llama4ForCausalLM(vllm_config, rng, mesh)
+
+         mock_loader_instance = MagicMock()
+         mock_loader_cls.return_value = mock_loader_instance
+         model.load_weights(rng)
+
+         mock_loader_cls.assert_called_once_with(vllm_config=vllm_config,
+                                                 hidden_size=32,
+                                                 attn_heads=40,
+                                                 num_key_value_heads=8,
+                                                 attn_head_dim=128)
+         mock_loader_instance.load_weights.assert_called_once_with(model)
+
+
+ class TestLlama4WeightLoader:
+     """Tests for the Llama4WeightLoader class."""
+
+     @pytest.fixture
+     def weight_loader(self):
+         # Build the loader against the mocked config to isolate the Llama4 loader's logic
+         return Llama4WeightLoader(vllm_config=MockVllmConfig("test-model"),
+                                   hidden_size=32,
+                                   attn_heads=40,
+                                   num_key_value_heads=8,
+                                   attn_head_dim=128)
+
+     @pytest.mark.parametrize("hf_key, expected_num", [
+         ("language_model.model.layers.15.self_attn.q_proj.weight", 15),
+         ("layers.0.feed_forward.router.weight", 0),
+         ("language_model.model.layers.99.norm.weight", 99),
+         ("language_model.model.norm.weight", None),
+         ("language_model.model.embed_tokens.weight", None),
+     ])
+     def test_get_layer_num(self, weight_loader, hf_key, expected_num):
+         """Tests the private _get_layer_num utility function."""
+         assert weight_loader._get_layer_num(hf_key) == expected_num
+
+     @pytest.mark.parametrize("hf_key, expected_num", [
+         ("language_model.model.layers.10.feed_forward.experts.4.down_proj.weight",
+          4),
+         ("language_model.model.layers.0.feed_forward.experts.0.gate_proj.weight_scale",
+          0),
+         ("language_model.model.layers.5.feed_forward.experts.128.up_proj.weight",
+          128),
+         ("language_model.model.norm.weight", None),
+         ("language_model.model.layers.15.self_attn.q_proj.weight", None),
+     ])
+     def test_get_expert_num(self, weight_loader, hf_key, expected_num):
+         """Tests the private _get_expert_num utility function to extract the expert index."""
+         assert weight_loader._get_expert_num(hf_key) == expected_num
+
+     @pytest.mark.parametrize("hf_key, expected", [
+         ("language_model.model.layers.15.self_attn.q_proj.weight",
+          "layers.15.attn.kernel_q_proj_DNH"),
+         ("language_model.model.layers.0.feed_forward.shared_expert.down_proj.weight",
+          "layers.0.shared_experts.kernel_down_proj_FD"),
+         ("language_model.model.embed_tokens.weight",
+          "embedder.input_embedding_table_VD"),
+         ("language_model.model.norm.weight", "final_norm.scale"),
+         ("language_model.lm_head.weight", "lm_head.input_embedding_table_DV"),
+         ("unmapped.key.name", "unmapped.key.name"),
+     ])
+     def test_map_loaded_to_standardized_name(self, weight_loader, hf_key,
+                                              expected):
+         """Tests the mapping from HuggingFace key names to internal names."""
+         assert weight_loader.map_loaded_to_standardized_name(
+             hf_key) == expected
+
+     def test_load_weights_transformation(self, weight_loader, rng, mesh):
+         """Tests that weights are correctly reshaped, transposed, and loaded."""
+         vllm_config = MockVllmConfig(model_name="llama4-small-test",
+                                      random_weights=False)
+
+         model = Llama4ForCausalLM(vllm_config, rng, mesh)
+
+         # Original weight shape is (vocab_size, hidden_size)
+         original_weight = jnp.ones((128, 32))
+         dummy_weights = [
+             ("language_model.model.embed_tokens.weight", original_weight),
+         ]
+         weight_loader.names_and_weights_generator = dummy_weights
+
+         # Mock get_param to return a mock param with the target shape (vocab_size, hidden_size)
+         mock_param = MockParamLlama4(shape=(128, 32))
+
+         with patch("tpu_inference.models.jax.llama4.get_param", return_value=mock_param), \
+              patch("tpu_inference.models.jax.llama4.shard_put", return_value=jnp.ones(mock_param.value.shape)) as mock_shard_put:
+
+             # Run the loader over the dummy weights
+             weight_loader.load_weights(model)
+
+             # Assert that shard_put was called with the correctly transposed weight
+             mock_shard_put.assert_called_once()
+
+             # Get the actual array passed to shard_put
+             called_with_weight = mock_shard_put.call_args[0][0]
+
+             # Check that the shape of the array passed to shard_put matches the model's expected shape.
+             assert called_with_weight.shape == mock_param.value.shape
+
+     def test_map_llama4_gate_up_proj(self, weight_loader, rng, mesh):
+         """Tests that gate_up_proj weights are correctly split, reshaped, transposed, and loaded."""
+         # Set up a dummy model and its config
+         model = Llama4ForCausalLM(MockVllmConfig("test-model"), rng, mesh)
+
+         # Create a dummy fused gate_up_proj weight tensor
+         hidden_size = 32
+         intermediate_size_moe = 8192
+         num_local_experts = 2
+         dummy_weight = jnp.ones(
+             (num_local_experts, hidden_size, 2 * intermediate_size_moe))
+
+         # Set up mocks and patches
+         mock_model_params = nnx.state(model)
+         mock_param = MockParamLlama4(shape=(2, hidden_size,
+                                             intermediate_size_moe))
+
+         # Configure the weight loader with the necessary attributes
+         weight_loader.is_verbose = False
+         layer_num = 0
+         weight_loader.names_and_weights_generator = [
+             (f"language_model.model.layers.{layer_num}.feed_forward.experts.gate_up_proj.weight",
+              dummy_weight),
+         ]
+
+         with patch("tpu_inference.models.jax.llama4.get_param", return_value=mock_param), \
+              patch("tpu_inference.models.jax.llama4.shard_put", return_value=jnp.ones(mock_param.value.shape)) as mock_shard_put:
+
+             # Call _map_llama4_gate_up_proj directly
+             weight_loader._map_llama4_gate_up_proj(
+                 model, mock_model_params,
+                 f"language_model.model.layers.{layer_num}.feed_forward.experts.gate_up_proj.weight",
+                 dummy_weight)
+         # Check that shard_put was called the correct number of times and with the correct weight shapes
+         assert mock_shard_put.call_count == 2
+         # call_args_list holds the arguments of every call that was made.
+         for call in mock_shard_put.call_args_list:
+             assert call[0][0].shape == (num_local_experts, 32, 8192)
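
Aside: the fused-weight handling that test_map_llama4_gate_up_proj exercises amounts to splitting one (experts, hidden, 2 * intermediate) checkpoint tensor into gate and up halves before sharding. A minimal standalone sketch of that split, inferred from the test's assertions (an illustration, not the actual _map_llama4_gate_up_proj implementation):

    import jax.numpy as jnp

    num_local_experts, hidden_size, intermediate_size_moe = 2, 32, 8192
    # Fused checkpoint tensor, as built in the test above: (E, D, 2F)
    fused = jnp.ones((num_local_experts, hidden_size, 2 * intermediate_size_moe))
    # Split the last axis into the gate and up projections: two (E, D, F) halves
    gate_proj, up_proj = jnp.split(fused, 2, axis=-1)
    assert gate_proj.shape == (num_local_experts, hidden_size, intermediate_size_moe)
    assert up_proj.shape == gate_proj.shape  # matches the two shard_put calls asserted above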
tests/models/jax/test_llama_eagle3.py (new file)
@@ -0,0 +1,197 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from unittest.mock import MagicMock, patch
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ import pytest
+ from flax import nnx
+ from flax.typing import PRNGKey
+ from jax.sharding import Mesh
+ from vllm.config import ModelConfig
+
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.models.jax.llama_eagle3 import (Eagle3LlamaDecoderLayer,
+                                                    EagleLlama3ForCausalLM)
+ from tpu_inference.runner.kv_cache import create_kv_caches
+
+
+ class MockSpeculativeConfig:
+
+     def __init__(self):
+         self.num_speculative_tokens = 3
+         self.method = "eagle3"
+         self.draft_model_config = None
+
+
+ class MockVllmConfig:
+
+     def __init__(self, model: str, draft_model: str, kv_cache_dtype):
+         self.model_config = ModelConfig(model)
+         self.model_config.dtype = jnp.bfloat16
+         self.load_config = MagicMock()
+         self.load_config.download_dir = None
+         self.speculative_config = MockSpeculativeConfig()
+         self.speculative_config.draft_model_config = ModelConfig(
+             draft_model,
+             dtype="bfloat16",
+             max_model_len=2048,
+             skip_tokenizer_init=True,
+             trust_remote_code=True)
+         self.cache_config = MagicMock(cache_dtype=kv_cache_dtype)
+
+
+ @pytest.fixture
+ def mock_vllm_config() -> MockVllmConfig:
+     return MockVllmConfig(model="meta-llama/Meta-Llama-3-8B-Instruct",
+                           draft_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
+                           kv_cache_dtype="auto")
+
+
+ @pytest.fixture(scope="module")
+ def mesh():
+     """Creates a mesh with 1 device."""
+     if not jax.devices():
+         pytest.skip("No JAX devices available for mesh creation.")
+
+     devices = np.array(jax.local_devices()[:1])
+     device_mesh = devices.reshape((1, 1, -1))
+
+     with Mesh(device_mesh, axis_names=('data', 'attn_dp', 'model')) as m:
+         yield m
+
+
+ @pytest.fixture
+ def mock_model_inputs(mock_vllm_config: MockVllmConfig):
+     """Provides mock inputs for the EagleLlama3 model."""
+     batch_size = 2
+     seq_len = 16
+     target_hidden_size = mock_vllm_config.model_config.get_hidden_size()
+
+     input_ids = jnp.ones((batch_size * seq_len, ), dtype=jnp.int32)
+     hidden_states = jnp.ones((batch_size * seq_len, target_hidden_size),
+                              dtype=jnp.bfloat16)
+     attention_metadata = AttentionMetadata(
+         input_positions=jnp.arange(batch_size * seq_len, dtype=jnp.int32),
+         block_tables=jnp.zeros((batch_size, 1), dtype=jnp.int32).reshape(-1),
+         seq_lens=jnp.full((batch_size, ), seq_len, dtype=jnp.int32),
+         query_start_loc=jnp.arange(0, (batch_size + 1) * seq_len,
+                                    seq_len,
+                                    dtype=jnp.int32),
+         request_distribution=jnp.array([0, 0, batch_size], dtype=jnp.int32),
+     )
+     return input_ids, hidden_states, attention_metadata
+
+
+ @pytest.fixture
+ def rng() -> PRNGKey:
+     """Provides a reusable JAX PRNGKey."""
+     return jax.random.PRNGKey(42)
+
+
+ class TestEagleLlama3ForCausalLM:
+     """Tests for the EagleLlama3ForCausalLM model."""
+
+     def test_eagle3_decoder_layer_init(self, mock_vllm_config: MockVllmConfig,
+                                        rng: PRNGKey, mesh: Mesh):
+         """Tests the initialization of the Eagle3LlamaDecoderLayer."""
+         hf_config = mock_vllm_config.speculative_config.draft_model_config.hf_config
+         dtype = jnp.bfloat16
+         rngs = nnx.Rngs(rng)
+
+         layer = Eagle3LlamaDecoderLayer(
+             hf_config,
+             dtype,
+             rngs,
+             mesh,
+             kv_cache_dtype=mock_vllm_config.cache_config.cache_dtype)
+
+         # Check that the projection layers are overridden with the correct input size
+         original_hidden_size = hf_config.hidden_size
+         expected_input_size = 2 * original_hidden_size
+
+         assert layer.self_attn.q_proj.kernel.value.shape[
+             0] == expected_input_size
+         assert layer.self_attn.k_proj.kernel.value.shape[
+             0] == expected_input_size
+         assert layer.self_attn.v_proj.kernel.value.shape[
+             0] == expected_input_size
+         assert isinstance(layer.hidden_norm, nnx.RMSNorm)
+
+     @pytest.mark.parametrize("mock_vllm_config", [
+         MockVllmConfig("meta-llama/Meta-Llama-3-8B-Instruct",
+                        "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", "auto"),
+         MockVllmConfig("meta-llama/Meta-Llama-3-8B-Instruct",
+                        "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", "fp8"),
+     ])
+     def test_forward_pass(self, mock_vllm_config: MockVllmConfig, rng: PRNGKey,
+                           mesh: Mesh, mock_model_inputs):
+         """Tests the forward pass of the EagleLlama3ForCausalLM model."""
+
+         draft_model_config = mock_vllm_config.speculative_config.draft_model_config
+         hf_config = draft_model_config.hf_config
+         model = EagleLlama3ForCausalLM(mock_vllm_config, rng, mesh)
+
+         input_ids, hidden_states, attention_metadata = mock_model_inputs
+
+         kv_caches = create_kv_caches(
+             num_blocks=4,
+             block_size=16,
+             num_kv_heads=hf_config.num_key_value_heads,
+             head_size=hf_config.hidden_size // hf_config.num_attention_heads,
+             mesh=mesh,
+             layer_names=["layer"] * hf_config.num_hidden_layers,
+             cache_dtype=jnp.float8_e4m3fn
+             if mock_vllm_config.cache_config.cache_dtype == "fp8" else
+             jnp.bfloat16)
+
+         _, output_hidden_states, aux_hidden_states = model(
+             kv_caches, input_ids, hidden_states, attention_metadata)
+
+         logits = model.compute_logits(output_hidden_states)
+
+         target_model_config = mock_vllm_config.model_config
+
+         assert output_hidden_states.shape == (
+             input_ids.shape[0], draft_model_config.get_hidden_size())
+         assert logits.shape == (input_ids.shape[0],
+                                 target_model_config.get_vocab_size())
+         assert len(aux_hidden_states) == 1
+         assert aux_hidden_states[0].shape == output_hidden_states.shape
+
+     @patch("tpu_inference.models.jax.llama_eagle3.load_hf_weights")
+     def test_load_weights(self, mock_load_hf_weights: MagicMock,
+                           mock_vllm_config: MockVllmConfig, rng: PRNGKey,
+                           mesh: Mesh):
+         """Tests that the load_weights function is called correctly."""
+         model = EagleLlama3ForCausalLM(mock_vllm_config, rng, mesh)
+         model.load_weights(rng)
+
+         mock_load_hf_weights.assert_called_once()
+         call_args = mock_load_hf_weights.call_args.kwargs
+
+         assert call_args["vllm_config"] is mock_vllm_config
+         assert call_args["model"] is model
+         assert call_args["mesh"] is mesh
+         assert call_args["is_draft_model"] is True
+
+         metadata_map = call_args["metadata_map"]
+         assert "midlayer.hidden_norm" in metadata_map.name_map
+         assert "lm_head" in metadata_map.name_map
+         assert "d2t" in metadata_map.name_map
+         assert "q_proj" in metadata_map.reshape_map
+         assert metadata_map.reshape_map["q_proj"][-1] == (
+             2 * mock_vllm_config.model_config.get_hidden_size())
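
Aside: the AttentionMetadata built by mock_model_inputs packs both requests into one flat token axis, and query_start_loc holds the cumulative token offsets delimiting each request. A small worked example of that layout, using the fixture's batch_size=2 and seq_len=16 (illustration only):

    import jax.numpy as jnp

    batch_size, seq_len = 2, 16
    query_start_loc = jnp.arange(0, (batch_size + 1) * seq_len, seq_len,
                                 dtype=jnp.int32)
    # -> [0, 16, 32]: request i owns flattened tokens
    # query_start_loc[i] : query_start_loc[i + 1]
    assert query_start_loc.tolist() == [0, 16, 32]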