tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.
Files changed (248)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +14 -0
  31. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  32. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
  35. tests/layers/__init__.py +13 -0
  36. tests/layers/common/__init__.py +13 -0
  37. tests/layers/common/test_attention_interface.py +156 -0
  38. tests/layers/common/test_quantization.py +149 -0
  39. tests/layers/jax/__init__.py +13 -0
  40. tests/layers/jax/attention/__init__.py +13 -0
  41. tests/layers/jax/attention/test_common_attention.py +103 -0
  42. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  43. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  44. tests/layers/jax/moe/__init__.py +13 -0
  45. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  46. tests/layers/jax/sample/__init__.py +13 -0
  47. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  48. tests/layers/jax/sample/test_sampling.py +115 -0
  49. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  50. tests/layers/jax/test_layers.py +155 -0
  51. tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
  52. tests/layers/jax/test_rope.py +93 -0
  53. tests/layers/jax/test_sharding.py +159 -0
  54. tests/layers/jax/test_transformer_block.py +152 -0
  55. tests/layers/vllm/__init__.py +13 -0
  56. tests/layers/vllm/test_attention.py +363 -0
  57. tests/layers/vllm/test_awq.py +406 -0
  58. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  59. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  61. tests/layers/vllm/test_fp8.py +17 -0
  62. tests/layers/vllm/test_mxfp4.py +320 -0
  63. tests/layers/vllm/test_unquantized.py +662 -0
  64. tests/layers/vllm/utils.py +87 -0
  65. tests/lora/__init__.py +13 -0
  66. tests/lora/conftest.py +14 -0
  67. tests/lora/test_bgmv.py +14 -0
  68. tests/lora/test_layers.py +25 -8
  69. tests/lora/test_lora.py +15 -1
  70. tests/lora/test_lora_perf.py +14 -0
  71. tests/models/__init__.py +13 -0
  72. tests/models/common/__init__.py +13 -0
  73. tests/models/common/test_model_loader.py +455 -0
  74. tests/models/jax/__init__.py +13 -0
  75. tests/models/jax/test_deepseek_v3.py +401 -0
  76. tests/models/jax/test_llama3.py +184 -0
  77. tests/models/jax/test_llama4.py +298 -0
  78. tests/models/jax/test_llama_eagle3.py +197 -0
  79. tests/models/jax/test_llama_guard_4.py +242 -0
  80. tests/models/jax/test_qwen2.py +172 -0
  81. tests/models/jax/test_qwen2_5_vl.py +605 -0
  82. tests/models/jax/test_qwen3.py +169 -0
  83. tests/models/jax/test_weight_loading.py +180 -0
  84. tests/models/jax/utils/__init__.py +13 -0
  85. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  86. tests/platforms/__init__.py +13 -0
  87. tests/platforms/test_tpu_platform.py +54 -0
  88. tests/runner/__init__.py +13 -0
  89. tests/runner/test_block_table.py +395 -0
  90. tests/runner/test_input_batch.py +226 -0
  91. tests/runner/test_kv_cache.py +220 -0
  92. tests/runner/test_kv_cache_manager.py +498 -0
  93. tests/runner/test_multimodal_manager.py +429 -0
  94. tests/runner/test_persistent_batch_manager.py +84 -0
  95. tests/runner/test_speculative_decoding_manager.py +368 -0
  96. tests/runner/test_structured_decoding_manager.py +220 -0
  97. tests/runner/test_tpu_runner.py +261 -0
  98. tests/runner/test_tpu_runner_dp.py +1099 -0
  99. tests/runner/test_tpu_runner_mesh.py +200 -0
  100. tests/runner/test_utils.py +411 -0
  101. tests/spec_decode/__init__.py +13 -0
  102. tests/spec_decode/test_eagle3.py +311 -0
  103. tests/test_base.py +14 -0
  104. tests/test_tpu_info.py +14 -0
  105. tests/test_utils.py +1 -43
  106. tests/worker/__init__.py +13 -0
  107. tests/worker/tpu_worker_test.py +414 -0
  108. tpu_inference/__init__.py +14 -0
  109. tpu_inference/core/__init__.py +13 -0
  110. tpu_inference/core/sched/__init__.py +13 -0
  111. tpu_inference/core/sched/dp_scheduler.py +372 -56
  112. tpu_inference/distributed/__init__.py +13 -0
  113. tpu_inference/distributed/jax_parallel_state.py +14 -0
  114. tpu_inference/distributed/tpu_connector.py +14 -9
  115. tpu_inference/distributed/utils.py +56 -4
  116. tpu_inference/executors/__init__.py +13 -0
  117. tpu_inference/executors/ray_distributed_executor.py +20 -3
  118. tpu_inference/experimental/__init__.py +13 -0
  119. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  120. tpu_inference/kernels/__init__.py +13 -0
  121. tpu_inference/kernels/collectives/__init__.py +13 -0
  122. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  123. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  124. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  125. tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
  126. tpu_inference/kernels/megablox/__init__.py +13 -0
  127. tpu_inference/kernels/megablox/common.py +54 -0
  128. tpu_inference/kernels/megablox/gmm.py +646 -0
  129. tpu_inference/kernels/mla/__init__.py +13 -0
  130. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  131. tpu_inference/kernels/mla/v1/kernel.py +20 -26
  132. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  133. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  134. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  135. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  136. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
  137. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
  138. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  139. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
  140. tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
  141. tpu_inference/layers/__init__.py +13 -0
  142. tpu_inference/layers/common/__init__.py +13 -0
  143. tpu_inference/layers/common/attention_interface.py +26 -19
  144. tpu_inference/layers/common/attention_metadata.py +14 -0
  145. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  146. tpu_inference/layers/common/quant_methods.py +15 -0
  147. tpu_inference/layers/common/quantization.py +282 -0
  148. tpu_inference/layers/common/sharding.py +22 -3
  149. tpu_inference/layers/common/utils.py +94 -0
  150. tpu_inference/layers/jax/__init__.py +13 -0
  151. tpu_inference/layers/jax/attention/__init__.py +13 -0
  152. tpu_inference/layers/jax/attention/attention.py +19 -6
  153. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
  154. tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
  155. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  156. tpu_inference/layers/jax/base.py +14 -0
  157. tpu_inference/layers/jax/constants.py +13 -0
  158. tpu_inference/layers/jax/layers.py +14 -0
  159. tpu_inference/layers/jax/misc.py +14 -0
  160. tpu_inference/layers/jax/moe/__init__.py +13 -0
  161. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  162. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  163. tpu_inference/layers/jax/moe/moe.py +43 -3
  164. tpu_inference/layers/jax/pp_utils.py +53 -0
  165. tpu_inference/layers/jax/rope.py +14 -0
  166. tpu_inference/layers/jax/rope_interface.py +14 -0
  167. tpu_inference/layers/jax/sample/__init__.py +13 -0
  168. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  169. tpu_inference/layers/jax/sample/sampling.py +15 -1
  170. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  171. tpu_inference/layers/jax/transformer_block.py +14 -0
  172. tpu_inference/layers/vllm/__init__.py +13 -0
  173. tpu_inference/layers/vllm/attention.py +4 -4
  174. tpu_inference/layers/vllm/fused_moe.py +100 -455
  175. tpu_inference/layers/vllm/linear.py +64 -0
  176. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  177. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  178. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  179. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  180. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  181. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  182. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  183. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
  184. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  188. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
  189. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  190. tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
  191. tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
  192. tpu_inference/lora/__init__.py +13 -0
  193. tpu_inference/lora/torch_lora_ops.py +8 -13
  194. tpu_inference/models/__init__.py +13 -0
  195. tpu_inference/models/common/__init__.py +13 -0
  196. tpu_inference/models/common/model_loader.py +37 -16
  197. tpu_inference/models/jax/__init__.py +13 -0
  198. tpu_inference/models/jax/deepseek_v3.py +113 -124
  199. tpu_inference/models/jax/gpt_oss.py +23 -7
  200. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  201. tpu_inference/models/jax/llama3.py +99 -36
  202. tpu_inference/models/jax/llama4.py +14 -0
  203. tpu_inference/models/jax/llama_eagle3.py +14 -0
  204. tpu_inference/models/jax/llama_guard_4.py +15 -1
  205. tpu_inference/models/jax/qwen2.py +17 -2
  206. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  207. tpu_inference/models/jax/qwen3.py +17 -2
  208. tpu_inference/models/jax/utils/__init__.py +13 -0
  209. tpu_inference/models/jax/utils/file_utils.py +14 -0
  210. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  211. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
  213. tpu_inference/models/jax/utils/weight_utils.py +32 -1
  214. tpu_inference/models/vllm/__init__.py +13 -0
  215. tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
  216. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  217. tpu_inference/platforms/__init__.py +14 -0
  218. tpu_inference/platforms/tpu_platform.py +27 -29
  219. tpu_inference/runner/__init__.py +13 -0
  220. tpu_inference/runner/compilation_manager.py +69 -35
  221. tpu_inference/runner/kv_cache.py +14 -0
  222. tpu_inference/runner/kv_cache_manager.py +15 -2
  223. tpu_inference/runner/lora_utils.py +16 -1
  224. tpu_inference/runner/multimodal_manager.py +16 -2
  225. tpu_inference/runner/persistent_batch_manager.py +14 -0
  226. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  227. tpu_inference/runner/structured_decoding_manager.py +14 -0
  228. tpu_inference/runner/tpu_runner.py +30 -10
  229. tpu_inference/spec_decode/__init__.py +13 -0
  230. tpu_inference/spec_decode/jax/__init__.py +13 -0
  231. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  232. tpu_inference/tpu_info.py +14 -0
  233. tpu_inference/utils.py +31 -30
  234. tpu_inference/worker/__init__.py +13 -0
  235. tpu_inference/worker/tpu_worker.py +23 -7
  236. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
  237. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  238. tpu_inference/layers/vllm/linear_common.py +0 -208
  239. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  240. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  241. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  242. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  245. tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
  246. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  247. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  248. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tests/{test_quantization.py → layers/jax/test_qwix.py} (renamed, +180 -50)

@@ -11,9 +11,9 @@ from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
 from qwix._src.providers import ptq
 
-import tpu_inference.models.jax.utils.quantization.quantization_utils as quantize_qwix  # noqa: E402
+import tpu_inference.models.jax.utils.qwix.qwix_utils as quantize_qwix  # noqa: E402
 from tpu_inference.models.common.model_loader import apply_qwix_quantization
-from tpu_inference.models.jax.utils.quantization.quantization_utils import (
+from tpu_inference.models.jax.utils.qwix.qwix_utils import (
     DEFAULT_MAX_NUM_BLOCKS_PER_REQ, DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS,
     DEFAULT_NUM_TOKENS_FOR_MODEL_INPUTS)
 
@@ -29,8 +29,7 @@ module_mocks = {
     'vllm.config': MagicMock(),
     'tpu_inference': MagicMock(),
     'tpu_inference.logger': MagicMock(init_logger=lambda name: MagicMock()),
-    'tpu_inference.models.jax.utils.quantization.quantization_utils':
-    MagicMock(),
+    'tpu_inference.models.jax.utils.qwix.qwix_utils': MagicMock(),
 }
 
 
@@ -136,16 +135,16 @@ class TestQwixQuantizeNnxModel(unittest.TestCase):
         self.model.vllm_config.sharding_config.total_dp_size = 1
 
         with patch(
-                "tpu_inference.models.jax.utils.quantization.quantization_utils.init_logger",
+                "tpu_inference.models.jax.utils.qwix.qwix_utils.init_logger",
                 return_value=MagicMock()
         ), patch(
                 "tpu_inference.utils.hbm_usage_gb",
                 return_value=[(0.0, 0.0), (0.0, 0.0)]
         ), patch(
-                "tpu_inference.models.jax.utils.quantization.quantization_utils.create_kv_caches",
+                "tpu_inference.models.jax.utils.qwix.qwix_utils.create_kv_caches",
                 return_value=self.mock_kv_caches
         ), patch(
-                "tpu_inference.models.jax.utils.quantization.quantization_utils.quantization_config_file_path_to_dict",
+                "tpu_inference.models.jax.utils.qwix.qwix_utils.quantization_config_file_path_to_dict",
                 return_value=self.qwix_config):
             returned_model = quantize_qwix.qwix_quantize_nnx_model(
                 model=self.model,
@@ -320,10 +319,9 @@ class TestApplyQwixQuantizationLogic(unittest.TestCase):
         self.assertIs(result2, self.mock_model)
 
     @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.qwix_quantize_nnx_model'
+        'tpu_inference.models.jax.utils.qwix.qwix_utils.qwix_quantize_nnx_model'
     )
-    @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.utils')
+    @patch('tpu_inference.models.jax.utils.qwix.qwix_utils.utils')
     def test_apply_to_abstract_model(self, mock_utils, mock_quantize_func):
         """Test quantization is correctly applied to an abstract model factory."""
         mock_utils.get_padded_num_heads.return_value = 8
@@ -360,10 +358,9 @@ class TestApplyQwixQuantizationLogic(unittest.TestCase):
         self.assertIs(result_model, quantized_model)
 
     @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.qwix_quantize_nnx_model'
+        'tpu_inference.models.jax.utils.qwix.qwix_utils.qwix_quantize_nnx_model'
     )
-    @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.utils')
+    @patch('tpu_inference.models.jax.utils.qwix.qwix_utils.utils')
     def test_apply_to_abstract_model_with_initialize_cache(
             self, mock_utils, mock_quantize_func):
         """Test abstract model quantization with 'initialize_cache' method."""
@@ -464,15 +461,13 @@ class TestLoadRandomWeightsIntoQwixAbstractModel(unittest.TestCase):
         # Mock model structure
         self.model = MagicMock(spec=['weight_loader', 'initialize_cache'])
         self.model.weight_loader = MagicMock(
-            spec=['scale_dtype', 'scale_shap_map_for_random_weight_loading'])
+            spec=['scale_dtype', 'scale_shape_map_for_random_weight_loading'])
         self.model.weight_loader.scale_dtype = jnp.float16
-        self.model.weight_loader.scale_shap_map_for_random_weight_loading = {}
+        self.model.weight_loader.scale_shape_map_for_random_weight_loading = {}
 
+    @patch('tpu_inference.models.jax.utils.qwix.qwix_utils.nnx.iter_graph')
     @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.nnx.iter_graph'
-    )
-    @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.get_random_sharded_array'
+        'tpu_inference.models.jax.utils.qwix.qwix_utils.get_random_sharded_array'
     )
     def test_successful_initialization(self, mock_get_random_array,
                                        mock_iter_graph):
@@ -485,6 +480,10 @@ class TestLoadRandomWeightsIntoQwixAbstractModel(unittest.TestCase):
         mock_random_array = jax.numpy.ones(1)
         mock_get_random_array.return_value = mock_random_array
 
+        self.model.weight_loader.scale_shape_map_for_random_weight_loading = {
+            'attention.wq': (1, 1)
+        }
+
         mock_iter_graph.return_value = [
             (('layers', '0', 'attention', 'wq', 'kernel'), mock_weight_param),
             (('layers', '0', 'attention', 'wq', 'array', 'scale'),
@@ -512,9 +511,7 @@ class TestLoadRandomWeightsIntoQwixAbstractModel(unittest.TestCase):
             quantize_qwix.load_random_weights_into_qwix_abstract_model(
                 self.rng, self.model, self.mesh, invalid_config)
 
-    @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.nnx.iter_graph'
-    )
+    @patch('tpu_inference.models.jax.utils.qwix.qwix_utils.nnx.iter_graph')
     def test_param_shape_setting_no_scale_map(self, mock_iter_graph):
         """Test correct scale shape calculation when not in the map."""
         old_weight_param_val = jnp.empty((128, 64))
@@ -528,26 +525,11 @@ class TestLoadRandomWeightsIntoQwixAbstractModel(unittest.TestCase):
                 mock_scale_var),
         ]
 
-        quantize_qwix.load_random_weights_into_qwix_abstract_model(
-            self.rng, self.model, self.mesh, self.quantization_config)
-
-        new_weight_param_val = mock_weight_param.value
-        new_scale_var_val = mock_scale_var.value
-
-        expected_scale_shape = (128 // 64, 64 // 64)
-        actual_scale_shape = new_scale_var_val.shape
-
-        expected_weight_shape = (128, 64)
-        actual_weight_shape = new_weight_param_val.shape
-
-        self.assertEqual(expected_scale_shape, actual_scale_shape)
-        self.assertEqual(expected_weight_shape, actual_weight_shape)
-        self.assertNotEqual(old_scale_var_val.shape, new_scale_var_val.shape)
-        assert jnp.not_equal(old_weight_param_val, new_weight_param_val).all()
+        with self.assertRaises(ValueError):
+            quantize_qwix.load_random_weights_into_qwix_abstract_model(
+                self.rng, self.model, self.mesh, self.quantization_config)
 
-    @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.nnx.iter_graph'
-    )
+    @patch('tpu_inference.models.jax.utils.qwix.qwix_utils.nnx.iter_graph')
     def test_param_shape_setting_with_scale_map(self, mock_iter_graph):
         """Test correct scale shape calculation when in the map."""
         old_weight_param_val = jnp.empty((128, 64))
@@ -557,8 +539,8 @@ class TestLoadRandomWeightsIntoQwixAbstractModel(unittest.TestCase):
 
         expected_scale_shape = (55, 34)
 
-        self.model.weight_loader.scale_shap_map_for_random_weight_loading = {
-            'wq': expected_scale_shape
+        self.model.weight_loader.scale_shape_map_for_random_weight_loading = {
+            'attention.wq': expected_scale_shape
         }
 
         mock_iter_graph.return_value = [
@@ -607,9 +589,7 @@ class TestLoadRandomWeightsIntoQwixAbstractModel(unittest.TestCase):
         mock_randint.assert_not_called()
         mock_normal.assert_called_once()
 
-    @patch(
-        "tpu_inference.models.jax.utils.quantization.quantization_utils.logger.warning"
-    )
+    @patch("tpu_inference.models.jax.utils.qwix.qwix_utils.logger.warning")
     @patch("jax.make_array_from_callback")
     def test_get_random_sharded_array_sharding_fallback(
             self, mock_make_array, mock_logger_warning):
@@ -651,7 +631,7 @@ class TestManualQwixQuantization(unittest.TestCase):
         self.calibration_method = 'max'
 
     @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.ptq.create_quantized_param'
+        'tpu_inference.models.jax.utils.qwix.qwix_utils.ptq.create_quantized_param'
     )
     def test_manually_quantize_qwix_weight(self, mock_create_param):
         """Test that manually_quantize_qwix_weight calls ptq.create_quantized_param correctly."""
@@ -675,9 +655,7 @@ class TestManualQwixQuantization(unittest.TestCase):
         self.assertEqual(passed_how_to_quantize.calibration_method,
                          self.calibration_method)
 
-    @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.ptq.quantize_act'
-    )
+    @patch('tpu_inference.models.jax.utils.qwix.qwix_utils.ptq.quantize_act')
     @patch('qwix.pallas.get_current_rule')
     def test_manually_quantize_qwix_activation(self, mock_get_rule,
                                                mock_quantize_act):
@@ -835,5 +813,157 @@ class TestGetQuantDtypeFromQwixConfig(unittest.TestCase):
         self.assertIsNone(quant_dtype)
 
 
+class TestGetDefaultQwixQuantizationConfig(unittest.TestCase):
+    """Tests for the get_default_qwix_quantization_config function."""
+
+    def setUp(self):
+        # Mocking the default configs that the function expects to find in the module
+        self.mock_deepseek_config = {
+            "qwix": {
+                "rules": [{
+                    "module_path": ".*",
+                    "tile_size": 0
+                }]
+            }
+        }
+        self.mock_llama_config = {"qwix": {"rules": [{"name": "llama_rule"}]}}
+        self.mock_gpt_oss_config = {"qwix": {"rules": [{"name": "gpt_rule"}]}}
+
+        # Patch the constants in the module where the function resides
+        self.patchers = [
+            patch(
+                "tpu_inference.models.jax.utils.qwix.qwix_utils.DEFAULT_DEEPSEEK_FP4_MLP_MOE_FP8_ATTN_CONFIG",
+                self.mock_deepseek_config),
+            patch(
+                "tpu_inference.models.jax.utils.qwix.qwix_utils.DEFAULT_LLAMA4_FP8_CONFIG",
+                self.mock_llama_config),
+            patch(
+                "tpu_inference.models.jax.utils.qwix.qwix_utils.DEFAULT_GPT_OSS_FP4_CONFIG",
+                self.mock_gpt_oss_config),
+            patch("tpu_inference.models.jax.utils.qwix.qwix_utils.logger",
+                  MagicMock())
+        ]
+        for p in self.patchers:
+            p.start()
+
+    def tearDown(self):
+        for p in self.patchers:
+            p.stop()
+
+    def test_skip_quantization_returns_none(self):
+        """Test that skip_quantization=True returns None immediately."""
+        result = quantize_qwix.get_default_qwix_quantization_config(
+            MagicMock(), True)
+        self.assertIsNone(result)
+
+    def test_unsupported_model_returns_none(self):
+        """Test that an unknown model type returns None."""
+        hf_config = MagicMock()
+        hf_config.model_type = "unknown_model"
+        result = quantize_qwix.get_default_qwix_quantization_config(
+            hf_config, False)
+        self.assertIsNone(result)
+
+    def test_deepseek_v3_success(self):
+        """Test DeepSeek V3 config with valid weight_block_size."""
+        hf_config = MagicMock()
+        hf_config.model_type = "DeepSeek_V3"
+        hf_config.quantization_config = {
+            "quant_method": "fp8",
+            "weight_block_size": [1, 128]
+        }
+
+        result = quantize_qwix.get_default_qwix_quantization_config(
+            hf_config, False)
+
+        # Check if tile_size was updated from 0 to 128
+        self.assertEqual(result["qwix"]["rules"][0]["tile_size"], 128)
+        # Ensure it's a deep copy (original mock shouldn't change)
+        self.assertEqual(
+            self.mock_deepseek_config["qwix"]["rules"][0]["tile_size"], 0)
+
+    def test_deepseek_v3_invalid_block_size(self):
+        """Test DeepSeek V3 raises ValueError on invalid block size format."""
+        hf_config = MagicMock()
+        hf_config.model_type = "deepseek_v3"
+        hf_config.quantization_config = {
+            "quant_method": "fp8",
+            "weight_block_size": [128]
+        }
+
+        with self.assertRaisesRegex(ValueError, "Invalid weight_block_size"):
+            quantize_qwix.get_default_qwix_quantization_config(
+                hf_config, False)
+
+    def test_deepseek_v3_invalid_block_size_2d_subchannel(self):
+        """Test DeepSeek V3 raises AssertionError when the first block dimension is not 1."""
+        hf_config = MagicMock()
+        hf_config.model_type = "deepseek_v3"
+        hf_config.quantization_config = {
+            "quant_method": "fp8",
+            "weight_block_size": [512, 512]
+        }
+
+        with self.assertRaisesRegex(AssertionError,
+                                    "Expected first dimension to be 1"):
+            quantize_qwix.get_default_qwix_quantization_config(
+                hf_config, False)
+
+    def test_deepseek_v3_no_weight_block_size(self):
+        """Test DeepSeek V3 raises AssertionError when weight_block_size is missing."""
+        hf_config = MagicMock()
+        hf_config.model_type = "DeepSeek_V3"
+        hf_config.quantization_config = {
+            "quant_method": "fp8",
+        }
+
+        with self.assertRaisesRegex(
+                AssertionError,
+                "Expected weight_block_size in quantization_config"):
+
+            quantize_qwix.get_default_qwix_quantization_config(
+                hf_config, False)
+
+    def test_deepseek_v3_tile_size_assertion(self):
+        """Test DeepSeek V3 raises AssertionError if tile_size is <= 1."""
+        hf_config = MagicMock()
+        hf_config.model_type = "deepseek_v3"
+        hf_config.quantization_config = {
+            "quant_method": "fp8",
+            "weight_block_size": [1, 1]
+        }
+
+        with self.assertRaises(AssertionError):
+            quantize_qwix.get_default_qwix_quantization_config(
+                hf_config, False)
+
+    def test_llama4_success(self):
+        """Test Llama 4 default config path."""
+        hf_config = MagicMock()
+        hf_config.model_type = "llama4"
+        hf_config.quantization_config = {"quant_method": "compressed-tensors"}
+
+        result = quantize_qwix.get_default_qwix_quantization_config(
+            hf_config, False)
+        self.assertEqual(result, self.mock_llama_config)
+
+    def test_gpt_oss_success(self):
+        """Test GPT-OSS default config path."""
+        hf_config = MagicMock()
+        hf_config.model_type = "gpt_oss"
+        hf_config.quantization_config = {"quant_method": "mxfp4"}
+
+        result = quantize_qwix.get_default_qwix_quantization_config(
+            hf_config, False)
+        self.assertEqual(result, self.mock_gpt_oss_config)
+
+    def test_missing_attributes_handled(self):
+        """Test that function handles hf_config objects missing model_type safely."""
+        hf_config = object()  # No attributes
+        result = quantize_qwix.get_default_qwix_quantization_config(
+            hf_config, False)
+        self.assertIsNone(result)
+
+
 if __name__ == '__main__':
     unittest.main()
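
The hunks above track the relocation of the qwix quantization helpers from tpu_inference.models.jax.utils.quantization.quantization_utils to tpu_inference.models.jax.utils.qwix.qwix_utils (see entries 212 and 239-244 in the file list). A minimal sketch of what the import update looks like for code that consumes these helpers; the SimpleNamespace stand-in for a Hugging Face config and the positional call below simply mirror the test cases above, not a documented API:

from types import SimpleNamespace

# New module path in 0.13.2.dev20251230; the old
# tpu_inference.models.jax.utils.quantization.quantization_utils path is removed.
import tpu_inference.models.jax.utils.qwix.qwix_utils as quantize_qwix

# Hypothetical HF-style config object; model_type / quant_method values are
# copied from the llama4 case in TestGetDefaultQwixQuantizationConfig above.
hf_config = SimpleNamespace(
    model_type="llama4",
    quantization_config={"quant_method": "compressed-tensors"})

# Positional arguments follow the tests above: (hf_config, skip-quantization flag).
default_config = quantize_qwix.get_default_qwix_quantization_config(hf_config, False)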
tests/layers/jax/test_rope.py (new file, +93)

@@ -0,0 +1,93 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import jax
+from jax import numpy as jnp
+from jax._src import test_util as jtu
+from jax.sharding import Mesh
+
+from tpu_inference.layers.jax.rope import (DeepseekScalingRotaryEmbedding,
+                                           RotaryEmbedding)
+
+
+class RotaryEmbeddingTest(jtu.JaxTestCase):
+
+    def test_apply_rope(self):
+        head_dim = 2
+        rope_theta = 10000
+        original_max_position_embeddings = 2
+        rope = RotaryEmbedding(
+            rotary_dim=head_dim,
+            rope_theta=rope_theta,
+            original_max_position_embeddings=original_max_position_embeddings,
+            dtype=jnp.float32)
+        rope.initialize_cache()
+        self.assertTrue(
+            rope.sin_cos_cache.shape == (original_max_position_embeddings,
+                                         head_dim))
+        expected_sin_cos = jnp.array([[1, 0], [0.5403023, 0.841471]],
+                                     dtype=jnp.float32)
+        self.assertArraysAllClose(rope.sin_cos_cache, expected_sin_cos)
+
+        num_tokens = 2
+        num_heads = 1
+        positions = jnp.arange(num_tokens)
+        x = jnp.ones((num_tokens, num_heads, head_dim))
+        x_rope = rope.apply_rope(positions, x)
+        expected_x_rope = jnp.array([[[1, 1]], [[-0.30116874, 1.3817732]]],
+                                    dtype=jnp.float32)
+        self.assertTrue(x_rope.shape == x.shape)
+        self.assertArraysAllClose(x_rope, expected_x_rope)
+
+
+class DeepseekScalingRotaryEmbeddingTest(jtu.JaxTestCase):
+
+    def test_apply_rope(self):
+        head_dim = 2
+        rope_theta = 10000
+        original_max_position_embeddings = 1
+        scaling_factor = 2
+        devices = jax.devices()
+        mesh = Mesh(devices, ('data', ))
+
+        rope = DeepseekScalingRotaryEmbedding(
+            rotary_dim=head_dim,
+            rope_theta=rope_theta,
+            original_max_position_embeddings=original_max_position_embeddings,
+            scaling_factor=scaling_factor,
+            dtype=jnp.float32)
+        rope.initialize_cache(mesh)
+        expected_padded_dim = 128
+        self.assertTrue(
+            rope.sin_cos_cache.shape == (scaling_factor *
+                                         original_max_position_embeddings,
+                                         expected_padded_dim))
+
+        valid_cache_slice = rope.sin_cos_cache[:, :head_dim]
+
+        expected_sin_cos = jnp.array([[1.0693147, 0], [0.5777532, 0.8997973]],
+                                     dtype=jnp.float32)
+
+        self.assertArraysAllClose(valid_cache_slice, expected_sin_cos)
+
+        num_tokens = 2
+        num_heads = 1
+        positions = jnp.arange(num_tokens)
+        x = jnp.ones((num_tokens, num_heads, head_dim))
+        x_rope = rope.apply_rope(positions, x)
+        expected_x_rope = jnp.array(
+            [[[1.0693147, 1.0693147]], [[-0.32204413, 1.4775505]]],
+            dtype=jnp.float32)
+        self.assertTrue(x_rope.shape == x.shape)
+        self.assertArraysAllClose(x_rope, expected_x_rope)
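
The expected values in RotaryEmbeddingTest can be checked by hand. A quick numeric sketch under the standard RoPE formulation (a single inverse frequency when rotary_dim == 2, cache rows of the form [cos(pos), sin(pos)]); the cache layout and rotation convention are inferred from the test rather than taken from the layer source, but they reproduce its expected numbers exactly:

import math

# With rotary_dim == 2 and rope_theta == 10000 there is one frequency:
# inv_freq = 1 / 10000**(0 / 2) == 1.0, so position p maps to [cos(p), sin(p)].
for pos in range(2):
    print(pos, round(math.cos(pos), 7), round(math.sin(pos), 7))
# 0 -> (1.0, 0.0); 1 -> (0.5403023, 0.841471), matching expected_sin_cos above.

# Rotating x = [1, 1] at position 1 reproduces the expected apply_rope output
# (small differences versus the test come from float32 rounding).
cos1, sin1 = math.cos(1.0), math.sin(1.0)
x0, x1 = 1.0, 1.0
print(x0 * cos1 - x1 * sin1, x0 * sin1 + x1 * cos1)
# ~ -0.3011687, 1.3817733, matching [[-0.30116874, 1.3817732]] above.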
tests/layers/jax/test_sharding.py (new file, +159)

@@ -0,0 +1,159 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from unittest.mock import MagicMock
+
+import jax
+
+from tpu_inference.layers.common.sharding import (Sharding, ShardingConfig,
+                                                  ShardingRulesConfig,
+                                                  ShardingStrategy)
+
+
+class TestSharding(unittest.TestCase):
+    """Unit test suite for the sharding configuration logic."""
+
+    def setUp(self):
+        """Sets up the testing environment before each test."""
+
+        self.mock_devices = [MagicMock(coords=i) for i in range(8)]
+        self.original_jax_devices = jax.devices
+        jax.devices = lambda: self.mock_devices
+
+    def tearDown(self):
+        """Restores the original jax.devices function after tests."""
+        jax.devices = self.original_jax_devices
+
+    def test_sharding_strategy_init(self):
+        """Tests the initialization of the ShardingStrategy."""
+        strategy = ShardingStrategy(
+            tensor_parallelism=2,
+            expert_parallelism=4,
+            data_parallelism=1,
+            sequence_parallelism=1,
+        )
+        self.assertEqual(strategy.tensor_parallelism, 2)
+        self.assertEqual(strategy.expert_parallelism, 4)
+
+    def test_sharding_config_init(self):
+        """Tests the initialization of ShardingConfig."""
+        config = ShardingConfig()
+        self.assertIsInstance(config.prefill_rules, ShardingRulesConfig)
+        self.assertIsInstance(config.generate_rules, ShardingRulesConfig)
+
+        custom_rules = ShardingRulesConfig(activation_ffw_td=("model", None))
+        config_with_rules = ShardingConfig(prefill_rules=custom_rules)
+        self.assertEqual(config_with_rules.prefill_rules.activation_ffw_td,
+                         ("model", None))
+
+    def test_apply_overrides(self):
+        """Tests the _apply_overrides method for valid and invalid keys."""
+        sharding = Sharding(
+            prefill_rules={},
+            generate_rules={},
+        )
+        config_obj = ShardingRulesConfig()
+
+        valid_overrides = {"activation_ffw_td": ("model", None)}
+        sharding._apply_overrides(config_obj, valid_overrides)
+        self.assertEqual(config_obj.activation_ffw_td, ("model", None))
+
+        invalid_overrides = {"non_existent_attribute": (None, "model")}
+        with self.assertRaises(AttributeError):
+            sharding._apply_overrides(config_obj, invalid_overrides)
+
+    def test_default_sharding_config(self):
+        """Tests that default sharding rules are created correctly."""
+        sharding = Sharding(
+            prefill_rules={},
+            generate_rules={},
+        )
+
+        sharding_cfg = sharding.get_sharding_cfg()
+        generate_rules = sharding_cfg.generate_rules
+
+        self.assertEqual(generate_rules.ffw_weight_df, (None, "model"))
+        self.assertEqual(generate_rules.moe_router_de, (None, "model"))
+        self.assertEqual(generate_rules.attn_q_weight_dnh,
+                         (None, "model", None))
+
+    def test_sharding_init_with_overrides(self):
+        """Tests Sharding initialization with programmatic overrides."""
+        generate_overrides = {"logits_tv": ("data", "model")}
+
+        sharding = Sharding(
+            generate_rules=generate_overrides,
+            prefill_rules={},
+        )
+
+        sharding_cfg = sharding.get_sharding_cfg()
+        self.assertNotEqual(sharding_cfg.generate_rules.logits_tv,
+                            (None, "model"))
+        self.assertEqual(sharding_cfg.generate_rules.logits_tv,
+                         ("data", "model"))
+
+    def test_get_overrides_from_vllm_config(self):
+        """Tests fetching sharding overrides from a mock VllmConfig."""
+
+        mock_vllm_config_prefill = MagicMock()
+        mock_vllm_config_prefill.additional_config = {
+            "sharding": {
+                "logical_rules": {
+                    "all": {
+                        "norm_scale": ("model", )
+                    },
+                    "prefill": {
+                        "activation_ffw_td": ("data", "model")
+                    },
+                }
+            }
+        }
+        sharding_prefill = Sharding(
+            vllm_config=mock_vllm_config_prefill,
+            prefill_rules={},
+            generate_rules={},
+        )
+        prefill_overrides = sharding_prefill._get_overrides("prefill")
+
+        self.assertEqual(prefill_overrides["norm_scale"], ("model", ))
+        self.assertEqual(prefill_overrides["activation_ffw_td"],
+                         ("data", "model"))
+
+        mock_vllm_config_generate = MagicMock()
+        mock_vllm_config_generate.additional_config = {
+            "sharding": {
+                "logical_rules": {
+                    "all": {
+                        "norm_scale": ("model", )
+                    },
+                    "prefill": {
+                        "activation_ffw_td": ("data", "model")
+                    },
+                }
+            }
+        }
+        sharding_generate = Sharding(
+            vllm_config=mock_vllm_config_generate,
+            prefill_rules={},
+            generate_rules={},
+        )
+        generate_overrides = sharding_generate._get_overrides("generate")
+
+        self.assertEqual(generate_overrides["norm_scale"], ("model", ))
+        self.assertNotIn("activation_ffw_td", generate_overrides)
+
+
+if __name__ == "__main__":
+    unittest.main()
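
For context, test_get_overrides_from_vllm_config above drives the Sharding overrides entirely through vllm's additional_config. A minimal sketch of that dictionary shape, with rule names copied from the test (how the config reaches Sharding in a real deployment is not shown in this diff):

# Sharding rule overrides as structured in the test above: entries under "all"
# are returned for both phases, while the "prefill" section is only returned by
# _get_overrides("prefill"), per the assertions in the test.
additional_config = {
    "sharding": {
        "logical_rules": {
            "all": {
                "norm_scale": ("model", ),
            },
            "prefill": {
                "activation_ffw_td": ("data", "model"),
            },
        }
    }
}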