tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.2rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic.

Files changed (250)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +78 -1
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +1 -43
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +14 -9
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +38 -7
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +17 -0
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +28 -5
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +74 -35
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +89 -26
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -64
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +72 -37
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +46 -17
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +14 -0
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +44 -17
  232. tpu_inference/spec_decode/__init__.py +13 -0
  233. tpu_inference/spec_decode/jax/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  235. tpu_inference/tpu_info.py +14 -0
  236. tpu_inference/utils.py +42 -36
  237. tpu_inference/worker/__init__.py +13 -0
  238. tpu_inference/worker/tpu_worker.py +63 -50
  239. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/METADATA +7 -9
  240. tpu_inference-0.13.2rc3.dist-info/RECORD +261 -0
  241. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  242. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  245. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  246. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  247. tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
  248. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/WHEEL +0 -0
  249. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/licenses/LICENSE +0 -0
  250. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,6 @@
  # SPDX-License-Identifier: Apache-2.0
  # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+ import copy
  import functools
  import os
  from typing import TYPE_CHECKING, Callable, List
@@ -34,17 +35,43 @@ DEFAULT_NUM_TOKENS_FOR_MODEL_INPUTS = 512
  DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS = 256
  DEFAULT_MAX_NUM_BLOCKS_PER_REQ = 16

- DEFAULT_DEEPSEEK_FP8_CONFIG = {
+ DEFAULT_DEEPSEEK_FP4_MLP_MOE_FP8_ATTN_CONFIG = {
      "qwix": {
          "use_abstract_model":
          True,
          "scale_dtype":
          "bfloat16",
          "rules": [
+             # Exclude router from quantization
              {
                  "module_path": ".*.custom_module.router.*",
                  "weight_qtype": None,
              },
+             # Avoid the combine expert ops
+             {
+                 "module_path": ".*combine_experts.*",
+                 "weight_qtype": None,
+             },
+             # Attention layers: keep FP8 for weights and activations
+             {
+                 "module_path": ".*.attn.*",
+                 "weight_qtype": "float8_e4m3fn",
+                 "act_qtype": "float8_e4m3fn",
+             },
+             # MoE experts: use FP4 for expert weights
+             {
+                 "module_path": ".*.custom_module.*",
+                 "weight_qtype": "float4_e2m1fn",
+                 "act_qtype": "float8_e4m3fn",
+                 "tile_size": 256,
+             },
+             # Shared experts: also FP4
+             {
+                 "module_path": ".*.shared_experts.*",
+                 "weight_qtype": "float4_e2m1fn",
+                 "act_qtype": "float8_e4m3fn",
+                 "tile_size": 256,
+             },
              {
                  "module_path": ".*",
                  "weight_qtype": "float8_e4m3fn",
@@ -166,9 +193,11 @@ def qwix_quantize_nnx_model(model: nnx.Module, qwix_config: List[dict],
          head_size=kv_cache_head_size,
          mesh=mesh,
          layer_names=[f"layer.{i}" for i in range(num_hidden_layers)],
-         cache_dtype=kv_cache_jnp_dtype)
+         cache_dtype=kv_cache_jnp_dtype,
+         use_mla=model.vllm_config.model_config.use_mla,
+     )

-     dp_size = mesh.shape.get("data", 1) * mesh.shape.get("attn", 1)
+     dp_size = model.vllm_config.sharding_config.total_dp_size

      # NOTE: the inputs don't need to match the actual ones, as long as the consumed weights are the same
      input_ids = jax.random.randint(rng,
@@ -396,8 +425,7 @@ def apply_qwix_on_abstract_model(vllm_config: "VllmConfig") -> bool:


  def get_default_qwix_quantization_config(
-         model_type: str, quant_method: str,
-         skip_quantization: bool) -> dict | None:
+         hf_config: dict, skip_quantization: bool) -> dict | None:
      """
      Some models are pre-quantized and in those cases, we want to return a default set of
      Qwix quantization rules (instead of forcing the user to pass in a quantization config each time).
@@ -415,9 +443,42 @@
      """
      if skip_quantization:
          return None
-     # TODO (jacobplatin): remove this so that we can support various quantization types
+     model_type = hf_config.model_type.lower() if hasattr(
+         hf_config, "model_type") else None
+     quant_method = hf_config.quantization_config["quant_method"] if hasattr(
+         hf_config, "quantization_config") else None
+     # TODO (jacobplatin): remove this so that we can support various quantization types + make
+     # more flexible
+     # NOTE (jacobplatin): we'll default to mixed FP8 (attention) + FP4 (MoE experts)
+     # for DeepSeek
      if model_type == "deepseek_v3" and quant_method == "fp8":
-         return DEFAULT_DEEPSEEK_FP8_CONFIG
+         config = copy.deepcopy(DEFAULT_DEEPSEEK_FP4_MLP_MOE_FP8_ATTN_CONFIG)
+
+         # Dynamically fetch block size from HF config if available
+         # Config fmt: 'weight_block_size': [1, 512] -> we want the 2nd dim for tile_size
+         # NOTE: if the checkpoint is not 1D subchannel, we will throw an error
+         hf_quant_config = hf_config.quantization_config
+         assert "weight_block_size" in hf_quant_config, "Expected weight_block_size in quantization_config"
+         block_size = hf_quant_config["weight_block_size"]
+         if isinstance(block_size, (list, tuple)) and len(block_size) == 2:
+             assert block_size[
+                 0] == 1, f"Expected first dimension to be 1 (unchanneled), but got {block_size[0]}! If you are trying to run quantized DeepSeek, we currently only support 1D-subchannel quantization and those models can be found here: https://huggingface.co/collections/jrplatin/deepseek-r1-1d-subchannel"
+             tile_size = block_size[1]
+             assert tile_size > 1, f"Expected tile_size > 1 for DeepSeek, but got {tile_size}"
+             logger.info(
+                 f"Detected DeepSeek tile_size from config: {tile_size}")
+
+             # Update tile_size in the rules, since we might not always use a 1D subchannel size of
+             # 256
+             for rule in config["qwix"]["rules"]:
+                 if "tile_size" in rule:
+                     rule["tile_size"] = tile_size
+         else:
+             raise ValueError(
+                 f"Invalid weight_block_size config: {block_size}, expected a list/tuple of length 2"
+             )
+
+         return config
      elif model_type == "llama4" and quant_method == "compressed-tensors":
          return DEFAULT_LLAMA4_FP8_CONFIG
      # MXFP4 (GPT-OSS): provide a default configuration to quantize MoE experts via Qwix
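
Note: the tile-size detection added above reduces to a small check on the HF quantization_config. A standalone sketch of the same logic (the helper name is made up for illustration):

    def detect_tile_size(weight_block_size) -> int:
        # Expect a 2-element [1, N] block size, i.e. a 1D-subchannel
        # checkpoint, and use N as the Qwix tile_size.
        if not (isinstance(weight_block_size, (list, tuple))
                and len(weight_block_size) == 2):
            raise ValueError(f"Invalid weight_block_size: {weight_block_size}")
        first_dim, tile_size = weight_block_size
        assert first_dim == 1, "only 1D-subchannel checkpoints are supported"
        assert tile_size > 1
        return tile_size

    # e.g. an HF config with 'weight_block_size': [1, 512]
    print(detect_tile_size([1, 512]))  # -> 512
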
@@ -436,14 +497,10 @@ def update_vllm_config_for_qwix_quantization(vllm_config: "VllmConfig"):
      # Qwix quantization config accordingly
      # NOTE: if a Qwix config is provided (via the `additional_config`), we'll
      # use that instead
-     model_type = vllm_config.model_config.hf_config.model_type.lower(
-     ) if hasattr(vllm_config.model_config.hf_config, "model_type") else None
-     quant_method = vllm_config.model_config.hf_config.quantization_config[
-         "quant_method"] if hasattr(vllm_config.model_config.hf_config,
-                                    "quantization_config") else None
+     hf_config = vllm_config.model_config.hf_config
      default_quantization_config = get_default_qwix_quantization_config(
-         model_type, quant_method,
-         vllm_config.additional_config.get("skip_quantization", False))
+         hf_config, vllm_config.additional_config.get("skip_quantization",
+                                                      False))

      maybe_existing_quantization_config = vllm_config.additional_config.get(
          "quantization")
@@ -500,7 +557,14 @@ def get_random_sharded_array(key: PRNGKey, mesh: Mesh, param: nnx.Param,
          maxval = jnp.array(jnp.iinfo(dtype).max, dtype=dtype)
          weight = jax.random.randint(key, param_shape, minval, maxval, dtype)
      else:
-         weight = jax.random.normal(key, param_shape, dtype)
+         # NOTE: _uniform() in random.py does not accept float4_e2m1fn
+         # Error: "TypeError: uniform only accepts 8-, 16-, 32-, or 64-bit dtypesgot float4_e2m1fn."
+         # Workaround: call function with dtype jnp.float8_e4m3fn and cast back to float4_e2m1fn
+         if dtype != "float4_e2m1fn":
+             weight = jax.random.normal(key, param_shape, dtype)
+         else:
+             weight = jax.random.normal(key, param_shape,
+                                        jnp.float8_e4m3fn).astype(dtype)

      def get_slice(index):
          return weight[index]
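
Note: a quick standalone illustration of the float4 workaround above (it assumes a JAX/ml_dtypes version that exposes jnp.float4_e2m1fn):

    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(0)
    shape = (4, 256)

    # jax.random.normal rejects 4-bit dtypes, so sample in float8_e4m3fn
    # and cast the result down to float4_e2m1fn afterwards.
    weight = jax.random.normal(key, shape, jnp.float8_e4m3fn).astype(jnp.float4_e2m1fn)
    print(weight.dtype)  # float4_e2m1fn
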
@@ -535,18 +599,16 @@ def load_random_weights_into_qwix_abstract_model(rng: PRNGKey,
      logger.info("Initializing Qwix-quantized model with random weights...")
      # TODO (jacobplatin): clean up this logic
      scale_dtype = model.weight_loader.scale_dtype
-     scale_shape_map = model.weight_loader.scale_shap_map_for_random_weight_loading if hasattr(
+     scale_shape_map = model.weight_loader.scale_shape_map_for_random_weight_loading if hasattr(
          model.weight_loader,
-         'scale_shap_map_for_random_weight_loading') else {}
+         'scale_shape_map_for_random_weight_loading') else {}
      quantization_block_sizes = quantization_config["weight_block_size"]
      assert len(
          quantization_block_sizes
      ) == 2, f"Expected only 2 quantization block sizes but got {quantization_block_sizes}"
-     quantization_block_size_n, _ = quantization_block_sizes[
-         0], quantization_block_sizes[1]

      # Iterate through all variables and initialize them
-     prev_param_shape = None
+
      for path, param in nnx.iter_graph(model):
          if not isinstance(param, nnx.Variable):
              continue
@@ -556,16 +618,17 @@
          is_qwix_scale = (path[-1] == 'scale' and path[-2] == "array")
          param_dtype = scale_dtype if is_qwix_scale else param.value.dtype
          param_shape = param.value.shape
-         # TODO (jacobplatin): clean this up
          if is_qwix_scale:
-             param_shape = scale_shape_map.get(
-                 path[3],
-                 tuple(dim // quantization_block_size_n
-                       for dim in prev_param_shape))
+             key = f"{path[2]}.{path[3]}"
+
+             if key in scale_shape_map:
+                 param_shape = scale_shape_map[key]
+             else:
+                 raise ValueError(
+                     f"Scale shape for {key} not found in scale_shape_map.")
          param.value = get_random_sharded_array(
              rng, mesh, param, param_shape, param_dtype,
              ".".join([str(x) for x in path]))
-         prev_param_shape = param_shape

      # Handles the DeepSeek case, where this needs to be called to make the cache weights
      # concrete
@@ -1,3 +1,16 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
  """Utilities for downloading model weights from HuggingFace."""

  import functools
@@ -67,7 +80,13 @@ def transpose_params(param_key: str, param_tensor: jax.Array, transpose_map):
  def reshape_params(param_key: str, param_tensor: jax.Array, shape_map):
      for key, new_shape in shape_map.items():
          if key in param_key:
-             return jnp.reshape(param_tensor, new_shape)
+             try:
+                 #TODO:(gpolovets) Add validation on whether reshape preserves data layout.
+                 return jnp.reshape(param_tensor, new_shape)
+             except TypeError:
+                 raise TypeError(
+                     f"Cannot reshape for key={key}, new_shape={new_shape}, param_shape={param_tensor.shape}"
+                 )
      return param_tensor  # Base case / no-op


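Note: a minimal usage sketch of reshape_params as changed above (the map keys and shapes are hypothetical, and it relies on the reshape_params defined in this file):

    import jax.numpy as jnp

    # Hypothetical map: any param whose key contains "q_proj.weight" is reshaped.
    shape_map = {"q_proj.weight": (32, 4, 128)}

    param = jnp.zeros((32, 512))
    out = reshape_params("model.layers.0.q_proj.weight", param, shape_map)
    print(out.shape)  # (32, 4, 128)

    # A size-mismatched new_shape now surfaces as a descriptive TypeError
    # instead of a bare jnp.reshape failure.
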
@@ -275,7 +294,8 @@ def _load_and_shard_weight(vllm_config,
                            hf_key: str,
                            hf_weight: jax.Array,
                            keep_original_dtype_keys_regex: list[str]
-                           | None = None):
+                           | None = None,
+                           pp_missing_layers: list[str] | None = None):
      name_map = metadata_map.name_map
      reshape_keys = metadata_map.reshape_map
      bias_reshape_keys = metadata_map.bias_reshape_map
@@ -331,6 +351,10 @@
          return
      model_key = name_map.get(hf_key, hf_key)

+     if pp_missing_layers and _is_pp_missing_layer(hf_key, pp_missing_layers):
+         logger.warning(
+             f"Skip loading {hf_key} as it doesn't belong to this PP stage.")
+         return
      model_weight, model_sharding = get_param_and_sharding(
          params, shardings, model_key)

@@ -394,6 +418,14 @@
      model_weight.value = shard(hf_weight, spec)


+ def _is_pp_missing_layer(hf_key: str, pp_missing_layers: list[str]) -> bool:
+     has_digit = any(char.isdigit() for char in hf_key)
+     # add the suffix after digits to avoid it matches "layers.10" with "layers.1"
+     suffix = "." if has_digit else ""
+     return any(f'{pp_missing_layer}{suffix}' in hf_key
+                for pp_missing_layer in pp_missing_layers)
+
+
  def _load_hf_weights_on_thread(
      vllm_config: VllmConfig,
      params: nnx.State,
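
Note: to see why _is_pp_missing_layer appends the "." suffix, a small illustration (the layer names are hypothetical; it calls the helper added above):

    # Without the "." suffix, a stage that is missing "model.layers.1" would
    # also skip "model.layers.10"; the suffix keeps the prefix match exact.
    pp_missing_layers = ["model.layers.1"]

    print(_is_pp_missing_layer("model.layers.1.self_attn.q_proj.weight",
                               pp_missing_layers))   # True  -> skipped on this stage
    print(_is_pp_missing_layer("model.layers.10.self_attn.q_proj.weight",
                               pp_missing_layers))   # False -> still loaded
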
@@ -402,6 +434,7 @@ def _load_hf_weights_on_thread(
      weights_file: str,
      filter_regex: Optional[str] = None,
      keep_original_dtype_keys_regex: Optional[list[str]] = None,
+     pp_missing_layers: list[str] | None = None,
  ):
      """Loads weights from a single weights file."""
      try:
@@ -420,6 +453,7 @@
                  hf_key,
                  hf_weight,
                  keep_original_dtype_keys_regex,
+                 pp_missing_layers,
              )


@@ -431,6 +465,7 @@ def load_hf_weights(
      filter_regex: Optional[str] = None,
      is_draft_model: bool = False,
      keep_original_dtype_keys_regex: Optional[list[str]] = None,
+     pp_missing_layers: list[str] | None = None,
  ):
      """Load weights into a JAX model from either an iterator or files."""
      params = nnx.state(model)
@@ -461,6 +496,7 @@
                  hf_key,
                  hf_weight_jax,
                  keep_original_dtype_keys_regex,
+                 pp_missing_layers=pp_missing_layers,
              )
          else:
              # File-based path (multi-threaded)
@@ -488,6 +524,7 @@
                  filter_regex=filter_regex,
                  keep_original_dtype_keys_regex=
                  keep_original_dtype_keys_regex,
+                 pp_missing_layers=pp_missing_layers,
              ) for weights_file in weights_files
          ]
          for future in futures:
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import copy
  import functools
  from collections.abc import Sequence
@@ -23,6 +37,7 @@ from vllm.model_executor.models import supports_lora, supports_multimodal
  from vllm.sequence import IntermediateTensors

  from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.layers.common.sharding import ShardingAxisName
  from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
  from tpu_inference.layers.vllm.sharding import shard_model_to_tpu
  from tpu_inference.logger import init_logger
@@ -197,7 +212,7 @@ class VllmModelWrapper:
              kwargs={
                  "input_ids": torch_view(input_ids),
                  "positions": torch_view(input_positions),
-                 "intermediate_tensors": None,
+                 "intermediate_tensors": intermediate_tensors,
                  "inputs_embeds": None,
              },
              tie_weights=False,
@@ -220,8 +235,10 @@

          @functools.partial(
              jax.jit,
-             out_shardings=(NamedSharding(self.mesh,
-                                          PartitionSpec("data", "model"))),
+             out_shardings=(NamedSharding(
+                 self.mesh,
+                 PartitionSpec(ShardingAxisName.MLP_DATA,
+                               ShardingAxisName.MLP_TENSOR))),
          )
          def compute_logits_func(
              params_and_buffers: Any,
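
Note: for context on the out_shardings change above, a minimal standalone JAX sketch of a NamedSharding over a two-axis mesh (the literal axis names here stand in for ShardingAxisName.MLP_DATA / ShardingAxisName.MLP_TENSOR):

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    # Arrange the available devices into a 2D mesh with named axes.
    devices = np.array(jax.devices()).reshape(1, -1)
    mesh = Mesh(devices, axis_names=("data", "model"))

    # Logits sharded along the batch ("data") and vocab ("model") axes.
    logits_sharding = NamedSharding(mesh, PartitionSpec("data", "model"))
    print(logits_sharding)
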
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  from contextlib import contextmanager
  from dataclasses import dataclass
  from typing import Dict, List, Optional
@@ -1,2 +1,16 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  # ruff: noqa
  from tpu_inference.platforms.tpu_platform import TpuPlatform
@@ -1,6 +1,6 @@
  # SPDX-License-Identifier: Apache-2.0

- from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast
+ from typing import TYPE_CHECKING, Optional, Tuple, Union, cast

  import jax.numpy as jnp
  import torch
@@ -8,23 +8,25 @@ import vllm.envs as vllm_envs
  from tpu_info import device
  from vllm.inputs import ProcessorInputs, PromptType
  from vllm.platforms.interface import Platform, PlatformEnum
- from vllm.sampling_params import SamplingParams, SamplingType

  from tpu_inference import envs
  from tpu_inference.layers.common.sharding import ShardingConfigManager
  from tpu_inference.logger import init_logger
- from tpu_inference.utils import to_jax_dtype, to_torch_dtype

  if TYPE_CHECKING:
-     from vllm.attention.backends.registry import _Backend
+     from vllm.attention.backends.registry import AttentionBackendEnum
+     from vllm.attention.selector import AttentionSelectorConfig
      from vllm.config import BlockSize, ModelConfig, VllmConfig
      from vllm.pooling_params import PoolingParams
+     from vllm.sampling_params import SamplingParams, SamplingType
  else:
      BlockSize = None
      ModelConfig = None
      VllmConfig = None
      PoolingParams = None
-     _Backend = None
+     AttentionBackendEnum = None
+     SamplingParams = None
+     SamplingType = None

  logger = init_logger(__name__)

@@ -44,25 +46,21 @@ class TpuPlatform(Platform):

      additional_env_vars: list[str] = [
          "PHASED_PROFILING_DIR", "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS",
-         "TPU_MULTIHOST_BACKEND", "VLLM_MLA_DISABLE", "TPU_BACKEND_TYPE"
+         "TPU_MULTIHOST_BACKEND", "VLLM_MLA_DISABLE", "TPU_BACKEND_TYPE",
+         "NEW_MODEL_DESIGN"
      ]

      @classmethod
-     def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
-                              dtype: jnp.dtype, kv_cache_dtype: Optional[str],
-                              block_size: int, use_v1: bool, use_mla: bool,
-                              has_sink: bool, use_sparse: bool,
-                              attn_type: Any) -> str:
-         from vllm.attention.backends.registry import _Backend
-         if selected_backend != _Backend.PALLAS:
+     def get_attn_backend_cls(cls, selected_backend: "AttentionBackendEnum",
+                              attn_selector_config: "AttentionSelectorConfig",
+                              **kwargs) -> str:
+         from vllm.attention.backends.registry import AttentionBackendEnum
+
+         if selected_backend != AttentionBackendEnum.PALLAS:
              logger.info("Cannot use %s backend on TPU.", selected_backend)

-         if use_v1:
-             logger.info("Using Pallas V1 backend.")
-             return "tpu_inference.layers.vllm.attention.PallasAttentionBackend"
-         else:
-             logger.info("Using Pallas backend.")
-             return "vllm.attention.backends.pallas.PallasAttentionBackend"
+         logger.info("Using Pallas V1 backend.")
+         return "tpu_inference.layers.vllm.attention.PallasAttentionBackend"

      @classmethod
      def get_device_name(cls, device_id: int = 0) -> str:
@@ -146,39 +144,21 @@
          if compilation_config.backend == "":
              compilation_config.backend = "openxla"

-         # If we use vLLM's model implementation in PyTorch, we should set it with torch version of the dtype.
-         impl = envs.MODEL_IMPL_TYPE
-
-         # NOTE(xiang): convert dtype to jnp.dtype
-         # NOTE(wenlong): skip this logic for mm model preprocessing
-         # For mm model preprocessors, it may need the output dtype to be torch.
-         # In order to avoid a PR to vLLM, we postpone the dtype checking during
-         # tpu_worker initialization
-         if not vllm_config.scheduler_config.is_multimodal_model or impl == "vllm":
-             model_dtype = vllm_config.model_config.dtype
-             try:
-                 dtype = to_jax_dtype(model_dtype)
-             except ValueError:
-                 logger.warning(f"{model_dtype=} is not supported. "
-                                "Falling back to jnp.bfloat16")
-                 dtype = jnp.bfloat16
-             if impl == "vllm":
-                 dtype = to_torch_dtype(dtype)
-             vllm_config.model_config.dtype = dtype
-
          # TODO(cuiq): remove this dependency.
-         from vllm.v1.attention.backends.pallas import PallasAttentionBackend
-         cache_config.block_size = PallasAttentionBackend.get_page_size(
-             vllm_config)  # type: ignore[assignment]
-         min_page_size = PallasAttentionBackend.get_min_page_size(vllm_config)
-         if min_page_size > cache_config.block_size:
-             logger.warning(
-                 "Increase the page size from %s to %s to make sure there's"
-                 "no SMEM OOM",
-                 cache_config.block_size,
-                 min_page_size,
-             )
-             cache_config.block_size = min_page_size  # type: ignore[assignment]
+         if vllm_config.model_config:
+             from vllm.v1.attention.backends.pallas import \
+                 PallasAttentionBackend
+             cache_config.block_size = PallasAttentionBackend.get_page_size(
+                 vllm_config)  # type: ignore[assignment]
+             min_page_size = PallasAttentionBackend.get_min_page_size(
+                 vllm_config)
+             if min_page_size > cache_config.block_size:
+                 logger.warning(
+                     "Increase the page size from %s to %s to avoid SMEM OOM",
+                     cache_config.block_size,
+                     min_page_size,
+                 )
+                 cache_config.block_size = min_page_size  # type: ignore[assignment]

          parallel_config = vllm_config.parallel_config
          scheduler_config = vllm_config.scheduler_config
@@ -188,12 +168,12 @@
          multihost_backend = envs.TPU_MULTIHOST_BACKEND
          if not multihost_backend:  # Single host
              if parallel_config.pipeline_parallel_size == 1:
-                 logger.info("Force using UniProcExecutor for JAX on \
-                     single host without pipeline parallelism.")
+                 logger.info("Force using UniProcExecutor for JAX on "
+                             "single host without pipeline parallelism.")
                  parallel_config.distributed_executor_backend = "uni"
              else:
-                 logger.info("Force using MultiprocExecutor for JAX on \
-                     single host with pipeline parallelism.")
+                 logger.info("Force using MultiprocExecutor for JAX on "
+                             "single host with pipeline parallelism.")
                  parallel_config.distributed_executor_backend = "mp"
          elif multihost_backend == "ray":
              from tpu_inference.executors.ray_distributed_executor import \
@@ -209,19 +189,21 @@

          if scheduler_config.is_multimodal_model and not \
              scheduler_config.disable_chunked_mm_input:
-             logger.warning("TPU does not support running Multimodal models"\
-                            " without setting `--disable_chunked_mm_input`. " \
-                            "Forcing --disable_chunked_mm_input.")
+             logger.warning("TPU does not support running Multimodal models"
+                            " without setting `--disable_chunked_mm_input`. "
+                            "Forcing --disable_chunked_mm_input.")
              scheduler_config.disable_chunked_mm_input = True

          kv_transfer_config = vllm_config.kv_transfer_config
          if kv_transfer_config is not None:
              assert kv_transfer_config.kv_connector == "TPUConnector"
-         # Late initialization to avoid circular import
-         from tpu_inference.models.jax.utils.quantization.quantization_utils import \
-             update_vllm_config_for_qwix_quantization
-
-         update_vllm_config_for_qwix_quantization(vllm_config)
+         # Late initialization to avoid circular import.
+         # Only perform qwix quantization if it is jax model.
+         if vllm_config.model_config is not None:
+             from tpu_inference.models.jax.utils.qwix.qwix_utils import \
+                 update_vllm_config_for_qwix_quantization
+             if vllm_config.model_config:
+                 update_vllm_config_for_qwix_quantization(vllm_config)

          from tpu_inference.core.sched.dp_scheduler import \
              update_vllm_config_for_dp_scheduler
@@ -249,10 +231,11 @@
      def validate_request(
          cls,
          prompt: PromptType,
-         params: Union[SamplingParams, PoolingParams],
+         params: Union["SamplingParams", PoolingParams],
          processed_inputs: ProcessorInputs,
      ) -> None:
          """Raises if this request is unsupported on this platform"""
+         from vllm.sampling_params import SamplingParams, SamplingType

          if isinstance(params, SamplingParams):
              if params.sampling_type == SamplingType.RANDOM_SEED:
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.