tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.2rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (250)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +78 -1
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +1 -43
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +14 -9
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +38 -7
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +17 -0
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +28 -5
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +74 -35
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +89 -26
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -64
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +72 -37
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +46 -17
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +14 -0
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +44 -17
  232. tpu_inference/spec_decode/__init__.py +13 -0
  233. tpu_inference/spec_decode/jax/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  235. tpu_inference/tpu_info.py +14 -0
  236. tpu_inference/utils.py +42 -36
  237. tpu_inference/worker/__init__.py +13 -0
  238. tpu_inference/worker/tpu_worker.py +63 -50
  239. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/METADATA +7 -9
  240. tpu_inference-0.13.2rc3.dist-info/RECORD +261 -0
  241. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  242. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  245. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  246. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  247. tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
  248. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/WHEEL +0 -0
  249. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/licenses/LICENSE +0 -0
  250. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/top_level.txt +0 -0
+++ b/tests/models/jax/test_llama_guard_4.py (new file)
@@ -0,0 +1,242 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from dataclasses import field
+ from types import SimpleNamespace
+ from typing import Any, Tuple
+ from unittest.mock import MagicMock, patch
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ import pytest
+ from flax.typing import PRNGKey
+ from jax.sharding import Mesh
+ from vllm.config import ModelConfig
+
+ from tpu_inference.models.jax.llama_guard_4 import (LlamaGuard4ForCausalLM,
+                                                     LlamaGuard4WeightLoader)
+
+
+ class MockParamLlamaGuard4:
+     """A mock for a parameter used in the LlamaGuard4 model."""
+     shape: Tuple[int, ...]
+     dtype: jnp.dtype = jnp.bfloat16
+     sharding_spec: Tuple[str | None, ...] | None = None
+     value: Any = field(init=False)
+     sharding: Any = field(init=False)
+
+     def __init__(self, shape=(32, 128)):
+         self.shape = shape
+         self.value = jnp.zeros(self.shape, dtype=self.dtype)
+         # The sharding spec is accessed during weight loading
+         self.sharding = SimpleNamespace(spec=self.sharding_spec)
+
+     # Allow the mock parameter's value to be updated
+     def __setattr__(self, name, value):
+         if name in ['value', 'shape', 'dtype', 'sharding', 'sharding_spec']:
+             self.__dict__[name] = value
+         else:
+             super().__setattr__(name, value)
+
+
+ class MockVllmConfig:
+     """A mock VllmConfig sufficient for testing the LlamaGuard4 model."""
+
+     def __init__(self,
+                  model_name: str,
+                  random_weights: bool = False,
+                  tensor_parallelism: int = 1):
+         self.model_config = MagicMock(spec=ModelConfig)
+         self.load_config = MagicMock()
+         self.load_config.download_dir = None
+
+         # Downsizing the following to avoid OOM
+         self.model_config.get_vocab_size.return_value = 1024
+         self.model_config.get_hidden_size.return_value = 128
+         self.model_config.model = model_name
+
+         self.additional_config = {
+             "random_weights": random_weights,
+             "sharding": {
+                 "sharding_strategy": {
+                     "tensor_parallelism": tensor_parallelism
+                 }
+             }
+         }
+
+         self.cache_config = MagicMock(cache_dtype="auto")
+
+         # Mock the underlying HF config values for parameter detection
+         # Downsized to avoid OOM
+         text_config_mock = MagicMock()
+         text_config_mock.num_attention_heads = 4
+         text_config_mock.num_key_value_heads = 2
+         text_config_mock.head_dim = 32
+
+         hf_config_mock = MagicMock()
+         hf_config_mock.text_config = text_config_mock
+
+         self.model_config.hf_config = hf_config_mock
+
+
+ @pytest.fixture(scope="module")
+ def mesh():
+     """
+     Creates a mesh with all required axes for testing.
+     """
+     if not jax.devices():
+         pytest.skip("No JAX devices available for mesh creation.")
+
+     devices = np.array(jax.local_devices())
+
+     num_devices = len(devices)
+     device_mesh = devices.reshape((num_devices, 1, 1, 1))
+
+     with Mesh(device_mesh,
+               axis_names=('data', 'attn_dp', 'model', 'expert')) as m:
+         yield m
+
+
+ @pytest.fixture
+ def rng() -> PRNGKey:
+     """Provides a reusable JAX PRNGKey."""
+     return jax.random.PRNGKey(42)
+
+
+ @pytest.fixture
+ def mock_vllm_config_llama_guard_4() -> MockVllmConfig:
+     return MockVllmConfig(model_name="meta-llama/Llama-Guard-4-12B")
+
+
+ class TestLlamaGuard4ForCausalLM:
+     """Tests for the main LlamaGuard4ForCausalLM model class."""
+
+     def test_init_llama_guard_4(self, mock_vllm_config_llama_guard_4, rng,
+                                 mesh):
+         """Tests correct initialization and parameter detection."""
+         model = LlamaGuard4ForCausalLM(mock_vllm_config_llama_guard_4, rng,
+                                        mesh)
+
+         # Check model name is correctly set in the config
+         assert "llama-guard-4" in model.vllm_config.model_config.model.lower()
+
+         assert model.hidden_size == 128
+
+     def test_create_model_with_random_weights(self,
+                                               mock_vllm_config_llama_guard_4,
+                                               rng, mesh):
+         """
+         Tests that random weight initialization creates concrete, non-zero-variance arrays.
+         """
+         with jax.set_mesh(mesh):
+             model = LlamaGuard4ForCausalLM(
+                 vllm_config=mock_vllm_config_llama_guard_4,
+                 rng=rng,
+                 mesh=mesh,
+                 force_random_weights=True)
+
+             embedding_weight = model.embedder.input_embedding_table_VD.value
+             attention_q_kernel = model.layers[0].attn.kernel_q_proj_DNH.value
+             final_norm_scale = model.final_norm.scale.value
+
+             assert isinstance(embedding_weight, jax.Array)
+             assert isinstance(attention_q_kernel, jax.Array)
+             assert isinstance(final_norm_scale, jax.Array)
+
+             assert jnp.std(embedding_weight) > 0
+             assert jnp.std(attention_q_kernel) > 0
+
+             assert jnp.all(final_norm_scale == 1.0)
+
+     @patch("tpu_inference.models.jax.llama_guard_4.LlamaGuard4WeightLoader")
+     def test_load_weights_called_correctly(self, mock_loader_cls, rng, mesh):
+         """Tests that the weight loader is called correctly for checkpoint loading."""
+         vllm_config = MockVllmConfig(model_name="llama-guard-4-test",
+                                      random_weights=False)
+         model = LlamaGuard4ForCausalLM(vllm_config, rng, mesh)
+
+         mock_loader_instance = MagicMock()
+         mock_loader_cls.return_value = mock_loader_instance
+         model.load_weights(rng)
+
+         mock_loader_cls.assert_called_once_with(vllm_config=vllm_config,
+                                                 hidden_size=128,
+                                                 attn_heads=4,
+                                                 num_key_value_heads=2,
+                                                 attn_head_dim=32)
+         mock_loader_instance.load_weights.assert_called_once_with(model)
+
+
+ class TestLlamaGuard4WeightLoader:
+     """Tests for the LlamaGuard4WeightLoader class."""
+
+     @pytest.fixture
+     def weight_loader(self):
+         return LlamaGuard4WeightLoader(
+             vllm_config=MockVllmConfig("test-model"),
+             hidden_size=5120,
+             attn_heads=40,
+             num_key_value_heads=8,
+             attn_head_dim=128)
+
+     @pytest.mark.parametrize("hf_key, expected", [
+         ("language_model.model.layers.15.self_attn.q_proj.weight",
+          "layers.15.attn.kernel_q_proj_DNH"),
+         ("language_model.model.layers.0.feed_forward.gate_proj.weight",
+          "layers.0.custom_module.kernel_gating_DF"),
+         ("language_model.model.embed_tokens.weight",
+          "embedder.input_embedding_table_VD"),
+         ("language_model.model.norm.weight", "final_norm.scale"),
+         ("language_model.lm_head.weight", "lm_head.input_embedding_table_DV"),
+         ("unmapped.key.name", "unmapped.key.name"),
+     ])
+     def test_map_loaded_to_standardized_name(self, weight_loader, hf_key,
+                                              expected):
+         """Tests the mapping from HuggingFace key names to internal names."""
+         assert weight_loader.map_loaded_to_standardized_name(
+             hf_key) == expected
+
+     def test_load_weights_transformation(self, weight_loader, rng, mesh):
+         """Tests that weights are correctly reshaped, transposed, and loaded."""
+         vllm_config = MockVllmConfig(model_name="llama-guard-4-small-test",
+                                      random_weights=False)
+
+         model = LlamaGuard4ForCausalLM(vllm_config, rng, mesh)
+
+         hidden_size = 5120
+         vocab_size = 202048
+
+         original_weight = jnp.ones((vocab_size, hidden_size))
+         dummy_weights = [
+             ("language_model.model.embed_tokens.weight", original_weight),
+         ]
+         weight_loader.names_and_weights_generator = dummy_weights
+
+         # Mock get_param to return a mock param with the target shape
+         mock_param = MockParamLlamaGuard4(shape=(vocab_size, hidden_size))
+
+         with patch("tpu_inference.models.jax.llama_guard_4.get_param", return_value=mock_param), \
+              patch("tpu_inference.models.jax.llama_guard_4.shard_put", return_value=jnp.ones(mock_param.value.shape)) as mock_shard_put:
+
+             weight_loader.load_weights(model)
+
+             # Assert that shard_put was called with the correctly transposed weight
+             mock_shard_put.assert_called_once()
+
+             # Get the actual array passed to shard_put
+             called_with_weight = mock_shard_put.call_args[0][0]
+
+             # Check if the shape of the array passed to shard_put matches the model's expected shape.
+             assert called_with_weight.shape == mock_param.value.shape
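
The test above exercises `load_weights` end-to-end by patching `get_param` and `shard_put` at the module under test, then asserting on what reaches the sharding call. Below is a minimal, self-contained sketch of that same mock-and-patch pattern; `load_embedding` and the two injected helpers are hypothetical stand-ins for illustration, not tpu_inference APIs.

from unittest.mock import MagicMock

import numpy as np


def load_embedding(weight, get_param, shard_put):
    # Hypothetical loader body: look up the target parameter, then hand the
    # weight to the sharding helper, mirroring the flow tested above.
    param = get_param("embedder.input_embedding_table_VD")
    return shard_put(weight, param.sharding)


def test_load_embedding_passes_weight_through():
    weight = np.ones((16, 8))
    mock_param = MagicMock(sharding=None)          # stand-in parameter
    get_param = MagicMock(return_value=mock_param)
    shard_put = MagicMock(side_effect=lambda w, sharding: w)

    out = load_embedding(weight, get_param, shard_put)

    # The sharding helper is called exactly once, with the original weight.
    shard_put.assert_called_once()
    assert shard_put.call_args[0][0] is weight
    assert out.shape == (16, 8)

The design choice is the same as in the suite: the expensive placement helpers are replaced with doubles, so the shape/routing logic can be asserted without real devices.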
+++ b/tests/models/jax/test_qwen2.py (new file)
@@ -0,0 +1,172 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from unittest.mock import MagicMock
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ import pytest
+ from flax import nnx
+ from flax.typing import PRNGKey
+ from jax.sharding import Mesh
+ from vllm.config import ModelConfig
+
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.models.jax.qwen2 import Qwen2ForCausalLM
+ from tpu_inference.runner.kv_cache import create_kv_caches
+
+
+ class MockVllmConfig:
+     """A mock VllmConfig sufficient for testing the Qwen2 model."""
+
+     def __init__(self, model: str, kv_cache_dtype: str):
+         self.model_config = ModelConfig(model)
+         self.model_config.dtype = jnp.bfloat16
+         self.load_config = MagicMock()
+         self.load_config.download_dir = None
+         self.cache_config = MagicMock(cache_dtype=kv_cache_dtype)
+
+
+ @pytest.fixture(scope="module")
+ def mesh():
+     """
+     Creates a mesh with 1 device.
+     """
+     if not jax.devices():
+         pytest.skip("No JAX devices available for mesh creation.")
+
+     devices = np.array(jax.local_devices()[:1])
+     num_devices = len(devices)
+     assert num_devices == 1
+     device_mesh = devices.reshape((num_devices, 1, 1, 1))
+
+     with Mesh(device_mesh,
+               axis_names=('data', 'attn_dp', 'expert', 'model')) as m:
+         yield m
+
+
+ @pytest.fixture
+ def mock_model_inputs():
+     num_tokens = 8
+     num_reqs = 1
+     max_num_blocks_per_req = 4
+     input_ids = jnp.ones((num_tokens, ), dtype=jnp.int32)
+     positions = jnp.ones((num_tokens, ), dtype=jnp.int32)
+     block_tables = jnp.zeros((num_reqs, max_num_blocks_per_req),
+                              dtype=jnp.int32).reshape(-1)
+     seq_lens = jnp.ones((num_reqs, ), dtype=jnp.int32)
+     query_start_loc = jnp.ones((num_reqs + 1, ), dtype=jnp.int32)
+     request_distribution = jnp.array([0, 0, 0], dtype=jnp.int32)
+
+     attention_metadata = AttentionMetadata(
+         input_positions=positions,
+         block_tables=block_tables,
+         seq_lens=seq_lens,
+         query_start_loc=query_start_loc,
+         request_distribution=request_distribution,
+     )
+     indices_do_sample = jnp.ones((num_reqs, ), dtype=jnp.int32)
+
+     return (input_ids, attention_metadata, indices_do_sample)
+
+
+ @pytest.fixture
+ def rng() -> PRNGKey:
+     """Provides a reusable JAX PRNGKey."""
+     return jax.random.PRNGKey(42)
+
+
+ class TestQwen2ForCausalLM:
+     """Tests for the main Qwen2ForCausalLM model class."""
+
+     @pytest.mark.parametrize("mock_vllm_config", [
+         MockVllmConfig("Qwen/Qwen2.5-1.5B", "auto"),
+         MockVllmConfig("Qwen/Qwen2.5-1.5B", "fp8")
+     ])
+     def test_qwen25_1_5b(self, mock_vllm_config, rng, mesh, mock_model_inputs):
+         """Tests model init and model forward for the 1.5B model variant."""
+
+         # Test model init
+         model = Qwen2ForCausalLM(mock_vllm_config, rng, mesh)
+         assert "1.5b" in model.vllm_config.model_config.model.lower()
+
+         model_config = mock_vllm_config.model_config
+         hf_config = model_config.hf_config
+
+         assert model.mesh.shape == {
+             "data": 1,
+             "attn_dp": 1,
+             "expert": 1,
+             "model": 1
+         }
+
+         layers = model.model.layers
+         assert len(layers) == hf_config.num_hidden_layers
+         assert isinstance(model.rng, nnx.Rngs)
+         assert model.model.lm_head == model.model.embed.embedding
+
+         attn = layers[0].self_attn
+         hidden_size = hf_config.hidden_size
+         num_heads = hf_config.num_attention_heads
+         num_kv_heads = hf_config.num_key_value_heads
+         rope_theta = hf_config.rope_theta
+         original_head_dim = hidden_size // num_heads
+         head_dim = 128
+         intermediate_size = hf_config.intermediate_size
+
+         assert attn.hidden_size == hidden_size
+         assert attn.num_heads == num_heads
+         assert attn.num_kv_heads == num_kv_heads
+         assert attn.rope_theta == rope_theta
+         assert attn.head_dim_original == original_head_dim
+         assert attn.head_dim == head_dim
+         assert attn.q_proj.kernel.shape == (hidden_size, num_heads, head_dim)
+         assert attn.k_proj.kernel.shape == (hidden_size, num_kv_heads,
+                                             head_dim)
+         assert attn.v_proj.kernel.shape == (hidden_size, num_kv_heads,
+                                             head_dim)
+         assert attn.o_proj.kernel.shape == (num_heads, head_dim, hidden_size)
+
+         mlp = layers[0].mlp
+         assert mlp.gate_proj.kernel.shape == (hidden_size, intermediate_size)
+         assert mlp.up_proj.kernel.shape == (hidden_size, intermediate_size)
+         assert mlp.down_proj.kernel.shape == (intermediate_size, hidden_size)
+
+         # Test model load
+         model.load_weights(rng)
+
+         # Test model forward
+         kv_caches = create_kv_caches(
+             num_blocks=4,
+             block_size=32,
+             num_kv_heads=num_kv_heads,
+             head_size=head_dim,
+             mesh=mesh,
+             layer_names=["layer"] * hf_config.num_hidden_layers,
+             cache_dtype=jnp.float8_e4m3fn
+             if mock_vllm_config.cache_config.cache_dtype == "fp8" else
+             jnp.bfloat16)
+         # 1 seq with 8 tokens
+         input_ids, attention_metadata, indices_do_sample = mock_model_inputs
+         kv_caches, hidden_states, aux_hidden_states = model(
+             kv_caches, input_ids, attention_metadata)
+         assert hidden_states.shape == (8, hidden_size)
+         assert len(aux_hidden_states) == 0
+
+         hidden_states = hidden_states[indices_do_sample]
+         assert hidden_states.shape == (1, hidden_size)
+
+         logits = model.compute_logits(hidden_states)
+         assert logits.shape == (1, hf_config.vocab_size)
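
Both suites follow the repository's pytest conventions (a module-scoped `mesh` fixture, a per-test `rng`), so they run directly once JAX sees at least one device. A hedged sketch of a local invocation, assuming the wheel's `tests/` tree is checked out next to a working JAX/vLLM install:

import sys

import pytest

if __name__ == "__main__":
    # Run only the two new model suites shown above; -v prints one line per
    # parametrized case (e.g. the "auto" and "fp8" kv-cache variants).
    sys.exit(
        pytest.main([
            "tests/models/jax/test_llama_guard_4.py",
            "tests/models/jax/test_qwen2.py",
            "-v",
        ]))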