tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (251):
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +22 -1
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +167 -97
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +31 -9
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +280 -210
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +77 -36
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +91 -31
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -4
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -71
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +158 -63
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +53 -30
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +54 -2
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +105 -57
  232. tpu_inference/runner/utils.py +2 -2
  233. tpu_inference/spec_decode/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/__init__.py +13 -0
  235. tpu_inference/spec_decode/jax/eagle3.py +65 -19
  236. tpu_inference/tpu_info.py +14 -0
  237. tpu_inference/utils.py +72 -44
  238. tpu_inference/worker/__init__.py +13 -0
  239. tpu_inference/worker/tpu_worker.py +65 -52
  240. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
  241. tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
  242. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  244. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  245. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  246. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  247. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  248. tpu_inference-0.11.1.dev202511270815.dist-info/RECORD +0 -174
  249. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
  250. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
  251. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
tests/{test_quantization.py → layers/jax/test_qwix.py} (renamed, +183 -50)

@@ -11,9 +11,9 @@ from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
 from qwix._src.providers import ptq
 
-import tpu_inference.models.jax.utils.quantization.quantization_utils as quantize_qwix  # noqa: E402
+import tpu_inference.models.jax.utils.qwix.qwix_utils as quantize_qwix  # noqa: E402
 from tpu_inference.models.common.model_loader import apply_qwix_quantization
-from tpu_inference.models.jax.utils.quantization.quantization_utils import (
+from tpu_inference.models.jax.utils.qwix.qwix_utils import (
     DEFAULT_MAX_NUM_BLOCKS_PER_REQ, DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS,
     DEFAULT_NUM_TOKENS_FOR_MODEL_INPUTS)
 
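This rename tracks the source move in entry 215 above (models/jax/utils/quantization/quantization_utils.py → models/jax/utils/qwix/qwix_utils.py), so any downstream code importing the old path breaks in this release. A minimal before/after sketch of the migration, using a symbol that appears in the hunks below:

# Before (0.11.1.dev*): the module lived under utils.quantization.
# from tpu_inference.models.jax.utils.quantization.quantization_utils import (
#     qwix_quantize_nnx_model)

# After (0.13.0rc2.post7): same symbol, new module path.
from tpu_inference.models.jax.utils.qwix.qwix_utils import (
    qwix_quantize_nnx_model)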
@@ -29,8 +29,7 @@ module_mocks = {
     'vllm.config': MagicMock(),
     'tpu_inference': MagicMock(),
     'tpu_inference.logger': MagicMock(init_logger=lambda name: MagicMock()),
-    'tpu_inference.models.jax.utils.quantization.quantization_utils':
-    MagicMock(),
+    'tpu_inference.models.jax.utils.qwix.qwix_utils': MagicMock(),
 }
 
 
@@ -112,6 +111,8 @@ class TestQwixQuantizeNnxModel(unittest.TestCase):
         self.mesh = Mesh(jax.devices(), ('model', ))
         self.rng = jax.random.PRNGKey(0)
         self.model = SimpleModel(rngs=nnx.Rngs(0))
+        self.model.vllm_config = MagicMock()
+        self.model.vllm_config.model_config.use_mla = False
 
         self.qwix_config = [
             {
@@ -131,18 +132,19 @@ class TestQwixQuantizeNnxModel(unittest.TestCase):
         """Test that qwix.quantize_model is called with the correct arguments."""
         quantized_model_mock = MagicMock(spec=nnx.Module)
         mock_quantize_model.return_value = quantized_model_mock
+        self.model.vllm_config.sharding_config.total_dp_size = 1
 
         with patch(
-                "tpu_inference.models.jax.utils.quantization.quantization_utils.init_logger",
+                "tpu_inference.models.jax.utils.qwix.qwix_utils.init_logger",
                 return_value=MagicMock()
         ), patch(
                 "tpu_inference.utils.hbm_usage_gb",
                 return_value=[(0.0, 0.0), (0.0, 0.0)]
         ), patch(
-                "tpu_inference.models.jax.utils.quantization.quantization_utils.create_kv_caches",
+                "tpu_inference.models.jax.utils.qwix.qwix_utils.create_kv_caches",
                 return_value=self.mock_kv_caches
         ), patch(
-                "tpu_inference.models.jax.utils.quantization.quantization_utils.quantization_config_file_path_to_dict",
+                "tpu_inference.models.jax.utils.qwix.qwix_utils.quantization_config_file_path_to_dict",
                 return_value=self.qwix_config):
             returned_model = quantize_qwix.qwix_quantize_nnx_model(
                 model=self.model,
@@ -317,10 +319,9 @@ class TestApplyQwixQuantizationLogic(unittest.TestCase):
         self.assertIs(result2, self.mock_model)
 
     @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.qwix_quantize_nnx_model'
+        'tpu_inference.models.jax.utils.qwix.qwix_utils.qwix_quantize_nnx_model'
     )
-    @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.utils')
+    @patch('tpu_inference.models.jax.utils.qwix.qwix_utils.utils')
     def test_apply_to_abstract_model(self, mock_utils, mock_quantize_func):
         """Test quantization is correctly applied to an abstract model factory."""
         mock_utils.get_padded_num_heads.return_value = 8
@@ -357,10 +358,9 @@ class TestApplyQwixQuantizationLogic(unittest.TestCase):
         self.assertIs(result_model, quantized_model)
 
     @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.qwix_quantize_nnx_model'
+        'tpu_inference.models.jax.utils.qwix.qwix_utils.qwix_quantize_nnx_model'
     )
-    @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.utils')
+    @patch('tpu_inference.models.jax.utils.qwix.qwix_utils.utils')
     def test_apply_to_abstract_model_with_initialize_cache(
             self, mock_utils, mock_quantize_func):
         """Test abstract model quantization with 'initialize_cache' method."""
@@ -461,15 +461,13 @@ class TestLoadRandomWeightsIntoQwixAbstractModel(unittest.TestCase):
         # Mock model structure
         self.model = MagicMock(spec=['weight_loader', 'initialize_cache'])
         self.model.weight_loader = MagicMock(
-            spec=['scale_dtype', 'scale_shap_map_for_random_weight_loading'])
+            spec=['scale_dtype', 'scale_shape_map_for_random_weight_loading'])
         self.model.weight_loader.scale_dtype = jnp.float16
-        self.model.weight_loader.scale_shap_map_for_random_weight_loading = {}
+        self.model.weight_loader.scale_shape_map_for_random_weight_loading = {}
 
+    @patch('tpu_inference.models.jax.utils.qwix.qwix_utils.nnx.iter_graph')
     @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.nnx.iter_graph'
-    )
-    @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.get_random_sharded_array'
+        'tpu_inference.models.jax.utils.qwix.qwix_utils.get_random_sharded_array'
     )
     def test_successful_initialization(self, mock_get_random_array,
                                        mock_iter_graph):
@@ -482,6 +480,10 @@ class TestLoadRandomWeightsIntoQwixAbstractModel(unittest.TestCase):
         mock_random_array = jax.numpy.ones(1)
         mock_get_random_array.return_value = mock_random_array
 
+        self.model.weight_loader.scale_shape_map_for_random_weight_loading = {
+            'attention.wq': (1, 1)
+        }
+
         mock_iter_graph.return_value = [
             (('layers', '0', 'attention', 'wq', 'kernel'), mock_weight_param),
             (('layers', '0', 'attention', 'wq', 'array', 'scale'),
@@ -509,9 +511,7 @@ class TestLoadRandomWeightsIntoQwixAbstractModel(unittest.TestCase):
         quantize_qwix.load_random_weights_into_qwix_abstract_model(
             self.rng, self.model, self.mesh, invalid_config)
 
-    @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.nnx.iter_graph'
-    )
+    @patch('tpu_inference.models.jax.utils.qwix.qwix_utils.nnx.iter_graph')
     def test_param_shape_setting_no_scale_map(self, mock_iter_graph):
         """Test correct scale shape calculation when not in the map."""
         old_weight_param_val = jnp.empty((128, 64))
@@ -525,26 +525,11 @@ class TestLoadRandomWeightsIntoQwixAbstractModel(unittest.TestCase):
             mock_scale_var),
         ]
 
-        quantize_qwix.load_random_weights_into_qwix_abstract_model(
-            self.rng, self.model, self.mesh, self.quantization_config)
-
-        new_weight_param_val = mock_weight_param.value
-        new_scale_var_val = mock_scale_var.value
-
-        expected_scale_shape = (128 // 64, 64 // 64)
-        actual_scale_shape = new_scale_var_val.shape
-
-        expected_weight_shape = (128, 64)
-        actual_weight_shape = new_weight_param_val.shape
-
-        self.assertEqual(expected_scale_shape, actual_scale_shape)
-        self.assertEqual(expected_weight_shape, actual_weight_shape)
-        self.assertNotEqual(old_scale_var_val.shape, new_scale_var_val.shape)
-        assert jnp.not_equal(old_weight_param_val, new_weight_param_val).all()
+        with self.assertRaises(ValueError):
+            quantize_qwix.load_random_weights_into_qwix_abstract_model(
+                self.rng, self.model, self.mesh, self.quantization_config)
 
-    @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.nnx.iter_graph'
-    )
+    @patch('tpu_inference.models.jax.utils.qwix.qwix_utils.nnx.iter_graph')
     def test_param_shape_setting_with_scale_map(self, mock_iter_graph):
         """Test correct scale shape calculation when in the map."""
         old_weight_param_val = jnp.empty((128, 64))
@@ -554,8 +539,8 @@ class TestLoadRandomWeightsIntoQwixAbstractModel(unittest.TestCase):
 
         expected_scale_shape = (55, 34)
 
-        self.model.weight_loader.scale_shap_map_for_random_weight_loading = {
-            'wq': expected_scale_shape
+        self.model.weight_loader.scale_shape_map_for_random_weight_loading = {
+            'attention.wq': expected_scale_shape
         }
 
         mock_iter_graph.return_value = [
@@ -604,9 +589,7 @@ class TestLoadRandomWeightsIntoQwixAbstractModel(unittest.TestCase):
         mock_randint.assert_not_called()
         mock_normal.assert_called_once()
 
-    @patch(
-        "tpu_inference.models.jax.utils.quantization.quantization_utils.logger.warning"
-    )
+    @patch("tpu_inference.models.jax.utils.qwix.qwix_utils.logger.warning")
    @patch("jax.make_array_from_callback")
    def test_get_random_sharded_array_sharding_fallback(
            self, mock_make_array, mock_logger_warning):
@@ -648,7 +631,7 @@ class TestManualQwixQuantization(unittest.TestCase):
         self.calibration_method = 'max'
 
     @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.ptq.create_quantized_param'
+        'tpu_inference.models.jax.utils.qwix.qwix_utils.ptq.create_quantized_param'
     )
     def test_manually_quantize_qwix_weight(self, mock_create_param):
         """Test that manually_quantize_qwix_weight calls ptq.create_quantized_param correctly."""
@@ -672,9 +655,7 @@ class TestManualQwixQuantization(unittest.TestCase):
         self.assertEqual(passed_how_to_quantize.calibration_method,
                          self.calibration_method)
 
-    @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.ptq.quantize_act'
-    )
+    @patch('tpu_inference.models.jax.utils.qwix.qwix_utils.ptq.quantize_act')
     @patch('qwix.pallas.get_current_rule')
     def test_manually_quantize_qwix_activation(self, mock_get_rule,
                                                mock_quantize_act):
@@ -832,5 +813,157 @@ class TestGetQuantDtypeFromQwixConfig(unittest.TestCase):
         self.assertIsNone(quant_dtype)
 
 
+class TestGetDefaultQwixQuantizationConfig(unittest.TestCase):
+    """Tests for the get_default_qwix_quantization_config function."""
+
+    def setUp(self):
+        # Mocking the default configs that the function expects to find in the module
+        self.mock_deepseek_config = {
+            "qwix": {
+                "rules": [{
+                    "module_path": ".*",
+                    "tile_size": 0
+                }]
+            }
+        }
+        self.mock_llama_config = {"qwix": {"rules": [{"name": "llama_rule"}]}}
+        self.mock_gpt_oss_config = {"qwix": {"rules": [{"name": "gpt_rule"}]}}
+
+        # Patch the constants in the module where the function resides
+        self.patchers = [
+            patch(
+                "tpu_inference.models.jax.utils.qwix.qwix_utils.DEFAULT_DEEPSEEK_FP8_CONFIG",
+                self.mock_deepseek_config),
+            patch(
+                "tpu_inference.models.jax.utils.qwix.qwix_utils.DEFAULT_LLAMA4_FP8_CONFIG",
+                self.mock_llama_config),
+            patch(
+                "tpu_inference.models.jax.utils.qwix.qwix_utils.DEFAULT_GPT_OSS_FP4_CONFIG",
+                self.mock_gpt_oss_config),
+            patch("tpu_inference.models.jax.utils.qwix.qwix_utils.logger",
+                  MagicMock())
+        ]
+        for p in self.patchers:
+            p.start()
+
+    def tearDown(self):
+        for p in self.patchers:
+            p.stop()
+
+    def test_skip_quantization_returns_none(self):
+        """Test that skip_quantization=True returns None immediately."""
+        result = quantize_qwix.get_default_qwix_quantization_config(
+            MagicMock(), True)
+        self.assertIsNone(result)
+
+    def test_unsupported_model_returns_none(self):
+        """Test that an unknown model type returns None."""
+        hf_config = MagicMock()
+        hf_config.model_type = "unknown_model"
+        result = quantize_qwix.get_default_qwix_quantization_config(
+            hf_config, False)
+        self.assertIsNone(result)
+
+    def test_deepseek_v3_success(self):
+        """Test DeepSeek V3 config with valid weight_block_size."""
+        hf_config = MagicMock()
+        hf_config.model_type = "DeepSeek_V3"
+        hf_config.quantization_config = {
+            "quant_method": "fp8",
+            "weight_block_size": [1, 128]
+        }
+
+        result = quantize_qwix.get_default_qwix_quantization_config(
+            hf_config, False)
+
+        # Check if tile_size was updated from 0 to 128
+        self.assertEqual(result["qwix"]["rules"][0]["tile_size"], 128)
+        # Ensure it's a deep copy (original mock shouldn't change)
+        self.assertEqual(
+            self.mock_deepseek_config["qwix"]["rules"][0]["tile_size"], 0)
+
+    def test_deepseek_v3_invalid_block_size(self):
+        """Test DeepSeek V3 raises ValueError on invalid block size format."""
+        hf_config = MagicMock()
+        hf_config.model_type = "deepseek_v3"
+        hf_config.quantization_config = {
+            "quant_method": "fp8",
+            "weight_block_size": [128]
+        }
+
+        with self.assertRaisesRegex(ValueError, "Invalid weight_block_size"):
+            quantize_qwix.get_default_qwix_quantization_config(
+                hf_config, False)
+
+    def test_deepseek_v3_invalid_block_size_2d_subchannel(self):
+        """Test DeepSeek V3 asserts when the first block dimension is not 1."""
+        hf_config = MagicMock()
+        hf_config.model_type = "deepseek_v3"
+        hf_config.quantization_config = {
+            "quant_method": "fp8",
+            "weight_block_size": [512, 512]
+        }
+
+        with self.assertRaisesRegex(AssertionError,
+                                    "Expected first dimension to be 1"):
+            quantize_qwix.get_default_qwix_quantization_config(
+                hf_config, False)
+
+    def test_deepseek_v3_no_weight_block_size(self):
+        """Test DeepSeek V3 asserts when weight_block_size is missing."""
+        hf_config = MagicMock()
+        hf_config.model_type = "DeepSeek_V3"
+        hf_config.quantization_config = {
+            "quant_method": "fp8",
+        }
+
+        with self.assertRaisesRegex(
+                AssertionError,
+                "Expected weight_block_size in quantization_config"):
+
+            quantize_qwix.get_default_qwix_quantization_config(
+                hf_config, False)
+
+    def test_deepseek_v3_tile_size_assertion(self):
+        """Test DeepSeek V3 raises AssertionError if tile_size is <= 1."""
+        hf_config = MagicMock()
+        hf_config.model_type = "deepseek_v3"
+        hf_config.quantization_config = {
+            "quant_method": "fp8",
+            "weight_block_size": [1, 1]
+        }
+
+        with self.assertRaises(AssertionError):
+            quantize_qwix.get_default_qwix_quantization_config(
+                hf_config, False)
+
+    def test_llama4_success(self):
+        """Test Llama 4 default config path."""
+        hf_config = MagicMock()
+        hf_config.model_type = "llama4"
+        hf_config.quantization_config = {"quant_method": "compressed-tensors"}
+
+        result = quantize_qwix.get_default_qwix_quantization_config(
+            hf_config, False)
+        self.assertEqual(result, self.mock_llama_config)
+
+    def test_gpt_oss_success(self):
+        """Test GPT-OSS default config path."""
+        hf_config = MagicMock()
+        hf_config.model_type = "gpt_oss"
+        hf_config.quantization_config = {"quant_method": "mxfp4"}
+
+        result = quantize_qwix.get_default_qwix_quantization_config(
+            hf_config, False)
+        self.assertEqual(result, self.mock_gpt_oss_config)
+
+    def test_missing_attributes_handled(self):
+        """Test that function handles hf_config objects missing model_type safely."""
+        hf_config = object()  # No attributes
+        result = quantize_qwix.get_default_qwix_quantization_config(
+            hf_config, False)
+        self.assertIsNone(result)
+
+
 if __name__ == '__main__':
     unittest.main()
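These TestGetDefaultQwixQuantizationConfig cases pin down how a DeepSeek V3 checkpoint's weight_block_size becomes a qwix tile_size: [1, 128] yields tile_size=128 on a deep copy of the default config, a one-element list raises ValueError, a first dimension other than 1 asserts, and a tile_size of 1 asserts. A minimal sketch of validation logic consistent with those assertions (illustrative only; the actual implementation lives in tpu_inference/models/jax/utils/qwix/qwix_utils.py and may differ):

import copy

def _deepseek_tile_size_config(quantization_config: dict,
                               default_config: dict) -> dict:
    # The HF config must carry a weight_block_size entry.
    assert "weight_block_size" in quantization_config, (
        "Expected weight_block_size in quantization_config")
    block_size = quantization_config["weight_block_size"]
    # Only a 2-D block shape is a valid format.
    if len(block_size) != 2:
        raise ValueError(f"Invalid weight_block_size: {block_size}")
    # Sub-channel quantization: only the last dimension may be tiled.
    assert block_size[0] == 1, "Expected first dimension to be 1"
    tile_size = block_size[1]
    assert tile_size > 1
    # Deep-copy so the module-level default config is never mutated,
    # which test_deepseek_v3_success checks explicitly.
    config = copy.deepcopy(default_config)
    config["qwix"]["rules"][0]["tile_size"] = tile_size
    return config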
tests/layers/jax/test_rope.py (new file, +93)

@@ -0,0 +1,93 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import jax
+from jax import numpy as jnp
+from jax._src import test_util as jtu
+from jax.sharding import Mesh
+
+from tpu_inference.layers.jax.rope import (DeepseekScalingRotaryEmbedding,
+                                           RotaryEmbedding)
+
+
+class RotaryEmbeddingTest(jtu.JaxTestCase):
+
+    def test_apply_rope(self):
+        head_dim = 2
+        rope_theta = 10000
+        original_max_position_embeddings = 2
+        rope = RotaryEmbedding(
+            rotary_dim=head_dim,
+            rope_theta=rope_theta,
+            original_max_position_embeddings=original_max_position_embeddings,
+            dtype=jnp.float32)
+        rope.initialize_cache()
+        self.assertTrue(
+            rope.sin_cos_cache.shape == (original_max_position_embeddings,
+                                         head_dim))
+        expected_sin_cos = jnp.array([[1, 0], [0.5403023, 0.841471]],
+                                     dtype=jnp.float32)
+        self.assertArraysAllClose(rope.sin_cos_cache, expected_sin_cos)
+
+        num_tokens = 2
+        num_heads = 1
+        positions = jnp.arange(num_tokens)
+        x = jnp.ones((num_tokens, num_heads, head_dim))
+        x_rope = rope.apply_rope(positions, x)
+        expected_x_rope = jnp.array([[[1, 1]], [[-0.30116874, 1.3817732]]],
+                                    dtype=jnp.float32)
+        self.assertTrue(x_rope.shape == x.shape)
+        self.assertArraysAllClose(x_rope, expected_x_rope)
+
+
+class DeepseekScalingRotaryEmbeddingTest(jtu.JaxTestCase):
+
+    def test_apply_rope(self):
+        head_dim = 2
+        rope_theta = 10000
+        original_max_position_embeddings = 1
+        scaling_factor = 2
+        devices = jax.devices()
+        mesh = Mesh(devices, ('data', ))
+
+        rope = DeepseekScalingRotaryEmbedding(
+            rotary_dim=head_dim,
+            rope_theta=rope_theta,
+            original_max_position_embeddings=original_max_position_embeddings,
+            scaling_factor=scaling_factor,
+            dtype=jnp.float32)
+        rope.initialize_cache(mesh)
+        expected_padded_dim = 128
+        self.assertTrue(
+            rope.sin_cos_cache.shape == (scaling_factor *
+                                         original_max_position_embeddings,
+                                         expected_padded_dim))
+
+        valid_cache_slice = rope.sin_cos_cache[:, :head_dim]
+
+        expected_sin_cos = jnp.array([[1.0693147, 0], [0.5777532, 0.8997973]],
+                                     dtype=jnp.float32)
+
+        self.assertArraysAllClose(valid_cache_slice, expected_sin_cos)
+
+        num_tokens = 2
+        num_heads = 1
+        positions = jnp.arange(num_tokens)
+        x = jnp.ones((num_tokens, num_heads, head_dim))
+        x_rope = rope.apply_rope(positions, x)
+        expected_x_rope = jnp.array(
+            [[[1.0693147, 1.0693147]], [[-0.32204413, 1.4775505]]],
+            dtype=jnp.float32)
+        self.assertTrue(x_rope.shape == x.shape)
+        self.assertArraysAllClose(x_rope, expected_x_rope)
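The constants these RoPE tests assert can be reproduced by hand. With rotary_dim=2 there is a single frequency pair with inv_freq = rope_theta**0 = 1.0, so position p rotates by angle p; the DeepSeek variant additionally scales the cache by an attention factor, which for these expected values works out to the YaRN formula 0.1 * ln(scaling_factor) + 1.0. A quick sketch under those assumptions:

import math
import jax.numpy as jnp

# Plain RoPE cache for positions [0, 1]: rows are [cos(p), sin(p)].
angles = jnp.arange(2, dtype=jnp.float32)  # inv_freq = 1.0 for the only pair
print(jnp.stack([jnp.cos(angles), jnp.sin(angles)], axis=-1))
# [[1.0, 0.0], [0.5403023, 0.841471]]  -- matches expected_sin_cos

# Rotating x = [1, 1] at position 1:
c, s = math.cos(1.0), math.sin(1.0)
print(1 * c - 1 * s, 1 * s + 1 * c)  # -0.30116874, 1.3817732

# DeepSeek/YaRN scaling with scaling_factor = 2:
mscale = 0.1 * math.log(2.0) + 1.0
print(mscale)                  # 1.0693147 -- first cache row is [mscale, 0]
print(c * mscale, s * mscale)  # 0.5777532, 0.8997973 -- second cache row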
tests/layers/jax/test_sharding.py (new file, +159)

@@ -0,0 +1,159 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from unittest.mock import MagicMock
+
+import jax
+
+from tpu_inference.layers.common.sharding import (Sharding, ShardingConfig,
+                                                  ShardingRulesConfig,
+                                                  ShardingStrategy)
+
+
+class TestSharding(unittest.TestCase):
+    """Unit test suite for the sharding configuration logic."""
+
+    def setUp(self):
+        """Sets up the testing environment before each test."""
+
+        self.mock_devices = [MagicMock(coords=i) for i in range(8)]
+        self.original_jax_devices = jax.devices
+        jax.devices = lambda: self.mock_devices
+
+    def tearDown(self):
+        """Restores the original jax.devices function after tests."""
+        jax.devices = self.original_jax_devices
+
+    def test_sharding_strategy_init(self):
+        """Tests the initialization of the ShardingStrategy."""
+        strategy = ShardingStrategy(
+            tensor_parallelism=2,
+            expert_parallelism=4,
+            data_parallelism=1,
+            sequence_parallelism=1,
+        )
+        self.assertEqual(strategy.tensor_parallelism, 2)
+        self.assertEqual(strategy.expert_parallelism, 4)
+
+    def test_sharding_config_init(self):
+        """Tests the initialization of ShardingConfig."""
+        config = ShardingConfig()
+        self.assertIsInstance(config.prefill_rules, ShardingRulesConfig)
+        self.assertIsInstance(config.generate_rules, ShardingRulesConfig)
+
+        custom_rules = ShardingRulesConfig(activation_ffw_td=("model", None))
+        config_with_rules = ShardingConfig(prefill_rules=custom_rules)
+        self.assertEqual(config_with_rules.prefill_rules.activation_ffw_td,
+                         ("model", None))
+
+    def test_apply_overrides(self):
+        """Tests the _apply_overrides method for valid and invalid keys."""
+        sharding = Sharding(
+            prefill_rules={},
+            generate_rules={},
+        )
+        config_obj = ShardingRulesConfig()
+
+        valid_overrides = {"activation_ffw_td": ("model", None)}
+        sharding._apply_overrides(config_obj, valid_overrides)
+        self.assertEqual(config_obj.activation_ffw_td, ("model", None))
+
+        invalid_overrides = {"non_existent_attribute": (None, "model")}
+        with self.assertRaises(AttributeError):
+            sharding._apply_overrides(config_obj, invalid_overrides)
+
+    def test_default_sharding_config(self):
+        """Tests that default sharding rules are created correctly."""
+        sharding = Sharding(
+            prefill_rules={},
+            generate_rules={},
+        )
+
+        sharding_cfg = sharding.get_sharding_cfg()
+        generate_rules = sharding_cfg.generate_rules
+
+        self.assertEqual(generate_rules.ffw_weight_df, (None, "model"))
+        self.assertEqual(generate_rules.moe_router_de, (None, "model"))
+        self.assertEqual(generate_rules.attn_q_weight_dnh,
+                         (None, "model", None))
+
+    def test_sharding_init_with_overrides(self):
+        """Tests Sharding initialization with programmatic overrides."""
+        generate_overrides = {"logits_tv": ("data", "model")}
+
+        sharding = Sharding(
+            generate_rules=generate_overrides,
+            prefill_rules={},
+        )
+
+        sharding_cfg = sharding.get_sharding_cfg()
+        self.assertNotEqual(sharding_cfg.generate_rules.logits_tv,
+                            (None, "model"))
+        self.assertEqual(sharding_cfg.generate_rules.logits_tv,
+                         ("data", "model"))
+
+    def test_get_overrides_from_vllm_config(self):
+        """Tests fetching sharding overrides from a mock VllmConfig."""
+
+        mock_vllm_config_prefill = MagicMock()
+        mock_vllm_config_prefill.additional_config = {
+            "sharding": {
+                "logical_rules": {
+                    "all": {
+                        "norm_scale": ("model", )
+                    },
+                    "prefill": {
+                        "activation_ffw_td": ("data", "model")
+                    },
+                }
+            }
+        }
+        sharding_prefill = Sharding(
+            vllm_config=mock_vllm_config_prefill,
+            prefill_rules={},
+            generate_rules={},
+        )
+        prefill_overrides = sharding_prefill._get_overrides("prefill")
+
+        self.assertEqual(prefill_overrides["norm_scale"], ("model", ))
+        self.assertEqual(prefill_overrides["activation_ffw_td"],
+                         ("data", "model"))
+
+        mock_vllm_config_generate = MagicMock()
+        mock_vllm_config_generate.additional_config = {
+            "sharding": {
+                "logical_rules": {
+                    "all": {
+                        "norm_scale": ("model", )
+                    },
+                    "prefill": {
+                        "activation_ffw_td": ("data", "model")
+                    },
+                }
+            }
+        }
+        sharding_generate = Sharding(
+            vllm_config=mock_vllm_config_generate,
+            prefill_rules={},
+            generate_rules={},
+        )
+        generate_overrides = sharding_generate._get_overrides("generate")
+
+        self.assertEqual(generate_overrides["norm_scale"], ("model", ))
+        self.assertNotIn("activation_ffw_td", generate_overrides)
+
+
+if __name__ == "__main__":
+    unittest.main()
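As test_get_overrides_from_vllm_config shows, logical sharding rules can be overridden through vLLM's additional_config: keys under "all" apply to both the prefill and generate rule sets, while keys under "prefill" (or "generate") apply to that phase only. A hedged sketch of the config shape a caller would pass (key layout taken from the test above; the exact engine-side plumbing is not shown here):

# Hypothetical caller-side dict; Sharding._get_overrides reads this layout.
additional_config = {
    "sharding": {
        "logical_rules": {
            # merged into both prefill and generate rules
            "all": {"norm_scale": ("model",)},
            # merged into the prefill rules only
            "prefill": {"activation_ffw_td": ("data", "model")},
        }
    }
}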