tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of tpu-inference has been flagged by the registry as potentially problematic.
Files changed (248)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +14 -0
  31. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  32. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
  35. tests/layers/__init__.py +13 -0
  36. tests/layers/common/__init__.py +13 -0
  37. tests/layers/common/test_attention_interface.py +156 -0
  38. tests/layers/common/test_quantization.py +149 -0
  39. tests/layers/jax/__init__.py +13 -0
  40. tests/layers/jax/attention/__init__.py +13 -0
  41. tests/layers/jax/attention/test_common_attention.py +103 -0
  42. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  43. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  44. tests/layers/jax/moe/__init__.py +13 -0
  45. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  46. tests/layers/jax/sample/__init__.py +13 -0
  47. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  48. tests/layers/jax/sample/test_sampling.py +115 -0
  49. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  50. tests/layers/jax/test_layers.py +155 -0
  51. tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
  52. tests/layers/jax/test_rope.py +93 -0
  53. tests/layers/jax/test_sharding.py +159 -0
  54. tests/layers/jax/test_transformer_block.py +152 -0
  55. tests/layers/vllm/__init__.py +13 -0
  56. tests/layers/vllm/test_attention.py +363 -0
  57. tests/layers/vllm/test_awq.py +406 -0
  58. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  59. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  61. tests/layers/vllm/test_fp8.py +17 -0
  62. tests/layers/vllm/test_mxfp4.py +320 -0
  63. tests/layers/vllm/test_unquantized.py +662 -0
  64. tests/layers/vllm/utils.py +87 -0
  65. tests/lora/__init__.py +13 -0
  66. tests/lora/conftest.py +14 -0
  67. tests/lora/test_bgmv.py +14 -0
  68. tests/lora/test_layers.py +25 -8
  69. tests/lora/test_lora.py +15 -1
  70. tests/lora/test_lora_perf.py +14 -0
  71. tests/models/__init__.py +13 -0
  72. tests/models/common/__init__.py +13 -0
  73. tests/models/common/test_model_loader.py +455 -0
  74. tests/models/jax/__init__.py +13 -0
  75. tests/models/jax/test_deepseek_v3.py +401 -0
  76. tests/models/jax/test_llama3.py +184 -0
  77. tests/models/jax/test_llama4.py +298 -0
  78. tests/models/jax/test_llama_eagle3.py +197 -0
  79. tests/models/jax/test_llama_guard_4.py +242 -0
  80. tests/models/jax/test_qwen2.py +172 -0
  81. tests/models/jax/test_qwen2_5_vl.py +605 -0
  82. tests/models/jax/test_qwen3.py +169 -0
  83. tests/models/jax/test_weight_loading.py +180 -0
  84. tests/models/jax/utils/__init__.py +13 -0
  85. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  86. tests/platforms/__init__.py +13 -0
  87. tests/platforms/test_tpu_platform.py +54 -0
  88. tests/runner/__init__.py +13 -0
  89. tests/runner/test_block_table.py +395 -0
  90. tests/runner/test_input_batch.py +226 -0
  91. tests/runner/test_kv_cache.py +220 -0
  92. tests/runner/test_kv_cache_manager.py +498 -0
  93. tests/runner/test_multimodal_manager.py +429 -0
  94. tests/runner/test_persistent_batch_manager.py +84 -0
  95. tests/runner/test_speculative_decoding_manager.py +368 -0
  96. tests/runner/test_structured_decoding_manager.py +220 -0
  97. tests/runner/test_tpu_runner.py +261 -0
  98. tests/runner/test_tpu_runner_dp.py +1099 -0
  99. tests/runner/test_tpu_runner_mesh.py +200 -0
  100. tests/runner/test_utils.py +411 -0
  101. tests/spec_decode/__init__.py +13 -0
  102. tests/spec_decode/test_eagle3.py +311 -0
  103. tests/test_base.py +14 -0
  104. tests/test_tpu_info.py +14 -0
  105. tests/test_utils.py +1 -43
  106. tests/worker/__init__.py +13 -0
  107. tests/worker/tpu_worker_test.py +414 -0
  108. tpu_inference/__init__.py +14 -0
  109. tpu_inference/core/__init__.py +13 -0
  110. tpu_inference/core/sched/__init__.py +13 -0
  111. tpu_inference/core/sched/dp_scheduler.py +372 -56
  112. tpu_inference/distributed/__init__.py +13 -0
  113. tpu_inference/distributed/jax_parallel_state.py +14 -0
  114. tpu_inference/distributed/tpu_connector.py +14 -9
  115. tpu_inference/distributed/utils.py +56 -4
  116. tpu_inference/executors/__init__.py +13 -0
  117. tpu_inference/executors/ray_distributed_executor.py +20 -3
  118. tpu_inference/experimental/__init__.py +13 -0
  119. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  120. tpu_inference/kernels/__init__.py +13 -0
  121. tpu_inference/kernels/collectives/__init__.py +13 -0
  122. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  123. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  124. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  125. tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
  126. tpu_inference/kernels/megablox/__init__.py +13 -0
  127. tpu_inference/kernels/megablox/common.py +54 -0
  128. tpu_inference/kernels/megablox/gmm.py +646 -0
  129. tpu_inference/kernels/mla/__init__.py +13 -0
  130. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  131. tpu_inference/kernels/mla/v1/kernel.py +20 -26
  132. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  133. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  134. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  135. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  136. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
  137. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
  138. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  139. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
  140. tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
  141. tpu_inference/layers/__init__.py +13 -0
  142. tpu_inference/layers/common/__init__.py +13 -0
  143. tpu_inference/layers/common/attention_interface.py +26 -19
  144. tpu_inference/layers/common/attention_metadata.py +14 -0
  145. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  146. tpu_inference/layers/common/quant_methods.py +15 -0
  147. tpu_inference/layers/common/quantization.py +282 -0
  148. tpu_inference/layers/common/sharding.py +22 -3
  149. tpu_inference/layers/common/utils.py +94 -0
  150. tpu_inference/layers/jax/__init__.py +13 -0
  151. tpu_inference/layers/jax/attention/__init__.py +13 -0
  152. tpu_inference/layers/jax/attention/attention.py +19 -6
  153. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
  154. tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
  155. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  156. tpu_inference/layers/jax/base.py +14 -0
  157. tpu_inference/layers/jax/constants.py +13 -0
  158. tpu_inference/layers/jax/layers.py +14 -0
  159. tpu_inference/layers/jax/misc.py +14 -0
  160. tpu_inference/layers/jax/moe/__init__.py +13 -0
  161. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  162. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  163. tpu_inference/layers/jax/moe/moe.py +43 -3
  164. tpu_inference/layers/jax/pp_utils.py +53 -0
  165. tpu_inference/layers/jax/rope.py +14 -0
  166. tpu_inference/layers/jax/rope_interface.py +14 -0
  167. tpu_inference/layers/jax/sample/__init__.py +13 -0
  168. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  169. tpu_inference/layers/jax/sample/sampling.py +15 -1
  170. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  171. tpu_inference/layers/jax/transformer_block.py +14 -0
  172. tpu_inference/layers/vllm/__init__.py +13 -0
  173. tpu_inference/layers/vllm/attention.py +4 -4
  174. tpu_inference/layers/vllm/fused_moe.py +100 -455
  175. tpu_inference/layers/vllm/linear.py +64 -0
  176. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  177. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  178. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  179. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  180. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  181. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  182. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  183. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
  184. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  188. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
  189. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  190. tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
  191. tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
  192. tpu_inference/lora/__init__.py +13 -0
  193. tpu_inference/lora/torch_lora_ops.py +8 -13
  194. tpu_inference/models/__init__.py +13 -0
  195. tpu_inference/models/common/__init__.py +13 -0
  196. tpu_inference/models/common/model_loader.py +37 -16
  197. tpu_inference/models/jax/__init__.py +13 -0
  198. tpu_inference/models/jax/deepseek_v3.py +113 -124
  199. tpu_inference/models/jax/gpt_oss.py +23 -7
  200. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  201. tpu_inference/models/jax/llama3.py +99 -36
  202. tpu_inference/models/jax/llama4.py +14 -0
  203. tpu_inference/models/jax/llama_eagle3.py +14 -0
  204. tpu_inference/models/jax/llama_guard_4.py +15 -1
  205. tpu_inference/models/jax/qwen2.py +17 -2
  206. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  207. tpu_inference/models/jax/qwen3.py +17 -2
  208. tpu_inference/models/jax/utils/__init__.py +13 -0
  209. tpu_inference/models/jax/utils/file_utils.py +14 -0
  210. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  211. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
  213. tpu_inference/models/jax/utils/weight_utils.py +32 -1
  214. tpu_inference/models/vllm/__init__.py +13 -0
  215. tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
  216. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  217. tpu_inference/platforms/__init__.py +14 -0
  218. tpu_inference/platforms/tpu_platform.py +27 -29
  219. tpu_inference/runner/__init__.py +13 -0
  220. tpu_inference/runner/compilation_manager.py +69 -35
  221. tpu_inference/runner/kv_cache.py +14 -0
  222. tpu_inference/runner/kv_cache_manager.py +15 -2
  223. tpu_inference/runner/lora_utils.py +16 -1
  224. tpu_inference/runner/multimodal_manager.py +16 -2
  225. tpu_inference/runner/persistent_batch_manager.py +14 -0
  226. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  227. tpu_inference/runner/structured_decoding_manager.py +14 -0
  228. tpu_inference/runner/tpu_runner.py +30 -10
  229. tpu_inference/spec_decode/__init__.py +13 -0
  230. tpu_inference/spec_decode/jax/__init__.py +13 -0
  231. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  232. tpu_inference/tpu_info.py +14 -0
  233. tpu_inference/utils.py +31 -30
  234. tpu_inference/worker/__init__.py +13 -0
  235. tpu_inference/worker/tpu_worker.py +23 -7
  236. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
  237. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  238. tpu_inference/layers/vllm/linear_common.py +0 -208
  239. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  240. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  241. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  242. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  245. tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
  246. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  247. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  248. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} RENAMED

@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import copy
 import functools
 import os
 from typing import TYPE_CHECKING, Callable, List
@@ -34,17 +35,43 @@ DEFAULT_NUM_TOKENS_FOR_MODEL_INPUTS = 512
 DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS = 256
 DEFAULT_MAX_NUM_BLOCKS_PER_REQ = 16
 
-DEFAULT_DEEPSEEK_FP8_CONFIG = {
+DEFAULT_DEEPSEEK_FP4_MLP_MOE_FP8_ATTN_CONFIG = {
    "qwix": {
        "use_abstract_model":
        True,
        "scale_dtype":
        "bfloat16",
        "rules": [
+           # Exclude router from quantization
            {
                "module_path": ".*.custom_module.router.*",
                "weight_qtype": None,
            },
+           # Avoid the combine expert ops
+           {
+               "module_path": ".*combine_experts.*",
+               "weight_qtype": None,
+           },
+           # Attention layers: keep FP8 for weights and activations
+           {
+               "module_path": ".*.attn.*",
+               "weight_qtype": "float8_e4m3fn",
+               "act_qtype": "float8_e4m3fn",
+           },
+           # MoE experts: use FP4 for expert weights
+           {
+               "module_path": ".*.custom_module.*",
+               "weight_qtype": "float4_e2m1fn",
+               "act_qtype": "float8_e4m3fn",
+               "tile_size": 256,
+           },
+           # Shared experts: also FP4
+           {
+               "module_path": ".*.shared_experts.*",
+               "weight_qtype": "float4_e2m1fn",
+               "act_qtype": "float8_e4m3fn",
+               "tile_size": 256,
+           },
            {
                "module_path": ".*",
                "weight_qtype": "float8_e4m3fn",
@@ -398,8 +425,7 @@ def apply_qwix_on_abstract_model(vllm_config: "VllmConfig") -> bool:
 
 
 def get_default_qwix_quantization_config(
-        model_type: str, quant_method: str,
-        skip_quantization: bool) -> dict | None:
+        hf_config: dict, skip_quantization: bool) -> dict | None:
     """
     Some models are pre-quantized and in those cases, we want to return a default set of
     Qwix quantization rules (instead of forcing the user to pass in a quantization config each time).
@@ -417,9 +443,42 @@ def get_default_qwix_quantization_config(
     """
     if skip_quantization:
         return None
-    # TODO (jacobplatin): remove this so that we can support various quantization types
+    model_type = hf_config.model_type.lower() if hasattr(
+        hf_config, "model_type") else None
+    quant_method = hf_config.quantization_config["quant_method"] if hasattr(
+        hf_config, "quantization_config") else None
+    # TODO (jacobplatin): remove this so that we can support various quantization types + make
+    # more flexible
+    # NOTE (jacobplatin): we'll default to mixed FP8 (attention) + FP4 (MoE experts)
+    # for DeepSeek
     if model_type == "deepseek_v3" and quant_method == "fp8":
-        return DEFAULT_DEEPSEEK_FP8_CONFIG
+        config = copy.deepcopy(DEFAULT_DEEPSEEK_FP4_MLP_MOE_FP8_ATTN_CONFIG)
+
+        # Dynamically fetch block size from HF config if available
+        # Config fmt: 'weight_block_size': [1, 512] -> we want the 2nd dim for tile_size
+        # NOTE: if the checkpoint is not 1D subchannel, we will throw an error
+        hf_quant_config = hf_config.quantization_config
+        assert "weight_block_size" in hf_quant_config, "Expected weight_block_size in quantization_config"
+        block_size = hf_quant_config["weight_block_size"]
+        if isinstance(block_size, (list, tuple)) and len(block_size) == 2:
+            assert block_size[
+                0] == 1, f"Expected first dimension to be 1 (unchanneled), but got {block_size[0]}! If you are trying to run quantized DeepSeek, we currently only support 1D-subchannel quantization and those models can be found here: https://huggingface.co/collections/jrplatin/deepseek-r1-1d-subchannel"
+            tile_size = block_size[1]
+            assert tile_size > 1, f"Expected tile_size > 1 for DeepSeek, but got {tile_size}"
+            logger.info(
+                f"Detected DeepSeek tile_size from config: {tile_size}")
+
+            # Update tile_size in the rules, since we might not always use a 1D subchannel size of
+            # 256
+            for rule in config["qwix"]["rules"]:
+                if "tile_size" in rule:
+                    rule["tile_size"] = tile_size
+        else:
+            raise ValueError(
+                f"Invalid weight_block_size config: {block_size}, expected a list/tuple of length 2"
+            )
+
+        return config
     elif model_type == "llama4" and quant_method == "compressed-tensors":
         return DEFAULT_LLAMA4_FP8_CONFIG
     # MXFP4 (GPT-OSS): provide a default configuration to quantize MoE experts via Qwix
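Rule order matters in configs like the one above: a module is claimed by the first rule whose module_path regex matches it, which is why the router/expert exclusions come first and the catch-all ".*" FP8 rule sits last. A minimal sketch of that first-match resolution, assuming Qwix-style semantics (the matcher and module paths below are illustrative, not Qwix's actual implementation):

    import re

    # Illustrative rules mirroring the config above (assumption: first match wins).
    RULES = [
        {"module_path": ".*.custom_module.router.*", "weight_qtype": None},
        {"module_path": ".*.attn.*", "weight_qtype": "float8_e4m3fn"},
        {"module_path": ".*.custom_module.*", "weight_qtype": "float4_e2m1fn"},
        {"module_path": ".*", "weight_qtype": "float8_e4m3fn"},
    ]

    def resolve_weight_qtype(module_path: str):
        # Return the qtype of the first rule whose regex matches the module.
        for rule in RULES:
            if re.fullmatch(rule["module_path"], module_path):
                return rule["weight_qtype"]
        return None

    # The router stays unquantized even though the ".*" rule would match it.
    assert resolve_weight_qtype("layers.0.custom_module.router.gate") is None
    assert resolve_weight_qtype("layers.0.attn.q_proj") == "float8_e4m3fn"
    assert resolve_weight_qtype("layers.0.custom_module.experts.0") == "float4_e2m1fn"
    assert resolve_weight_qtype("layers.0.mlp.up_proj") == "float8_e4m3fn"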
@@ -438,14 +497,10 @@ def update_vllm_config_for_qwix_quantization(vllm_config: "VllmConfig"):
     # Qwix quantization config accordingly
     # NOTE: if a Qwix config is provided (via the `additional_config`), we'll
     # use that instead
-    model_type = vllm_config.model_config.hf_config.model_type.lower(
-    ) if hasattr(vllm_config.model_config.hf_config, "model_type") else None
-    quant_method = vllm_config.model_config.hf_config.quantization_config[
-        "quant_method"] if hasattr(vllm_config.model_config.hf_config,
-                                   "quantization_config") else None
+    hf_config = vllm_config.model_config.hf_config
     default_quantization_config = get_default_qwix_quantization_config(
-        model_type, quant_method,
-        vllm_config.additional_config.get("skip_quantization", False))
+        hf_config, vllm_config.additional_config.get("skip_quantization",
+                                                     False))
 
     maybe_existing_quantization_config = vllm_config.additional_config.get(
         "quantization")
@@ -502,7 +557,14 @@ def get_random_sharded_array(key: PRNGKey, mesh: Mesh, param: nnx.Param,
         maxval = jnp.array(jnp.iinfo(dtype).max, dtype=dtype)
         weight = jax.random.randint(key, param_shape, minval, maxval, dtype)
     else:
-        weight = jax.random.normal(key, param_shape, dtype)
+        # NOTE: _uniform() in random.py does not accept float4_e2m1fn
+        # Error: "TypeError: uniform only accepts 8-, 16-, 32-, or 64-bit dtypesgot float4_e2m1fn."
+        # Workaround: call function with dtype jnp.float8_e4m3fn and cast back to float4_e2m1fn
+        if dtype != "float4_e2m1fn":
+            weight = jax.random.normal(key, param_shape, dtype)
+        else:
+            weight = jax.random.normal(key, param_shape,
+                                       jnp.float8_e4m3fn).astype(dtype)
 
     def get_slice(index):
         return weight[index]
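The float4 branch above works around a real JAX restriction: the random sampler rejects sub-8-bit dtypes, so 4-bit parameters have to be drawn at a wider float type and cast down. A self-contained sketch of the same workaround (shape and seed are arbitrary; this assumes a JAX/ml_dtypes build whose jnp exposes float4_e2m1fn):

    import jax
    import jax.numpy as jnp

    key = jax.random.key(0)
    shape = (128, 256)

    # jax.random.normal(key, shape, jnp.float4_e2m1fn) raises TypeError,
    # so sample in float8_e4m3fn and cast the result down to 4 bits.
    weight = jax.random.normal(key, shape,
                               jnp.float8_e4m3fn).astype(jnp.float4_e2m1fn)
    assert weight.dtype == jnp.float4_e2m1fn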
@@ -537,18 +599,16 @@ def load_random_weights_into_qwix_abstract_model(rng: PRNGKey,
     logger.info("Initializing Qwix-quantized model with random weights...")
     # TODO (jacobplatin): clean up this logic
     scale_dtype = model.weight_loader.scale_dtype
-    scale_shape_map = model.weight_loader.scale_shap_map_for_random_weight_loading if hasattr(
+    scale_shape_map = model.weight_loader.scale_shape_map_for_random_weight_loading if hasattr(
         model.weight_loader,
-        'scale_shap_map_for_random_weight_loading') else {}
+        'scale_shape_map_for_random_weight_loading') else {}
     quantization_block_sizes = quantization_config["weight_block_size"]
     assert len(
         quantization_block_sizes
     ) == 2, f"Expected only 2 quantization block sizes but got {quantization_block_sizes}"
-    quantization_block_size_n, _ = quantization_block_sizes[
-        0], quantization_block_sizes[1]
 
     # Iterate through all variables and initialize them
-    prev_param_shape = None
+
     for path, param in nnx.iter_graph(model):
         if not isinstance(param, nnx.Variable):
             continue
@@ -558,16 +618,17 @@ def load_random_weights_into_qwix_abstract_model(rng: PRNGKey,
         is_qwix_scale = (path[-1] == 'scale' and path[-2] == "array")
         param_dtype = scale_dtype if is_qwix_scale else param.value.dtype
         param_shape = param.value.shape
-        # TODO (jacobplatin): clean this up
         if is_qwix_scale:
-            param_shape = scale_shape_map.get(
-                path[3],
-                tuple(dim // quantization_block_size_n
-                      for dim in prev_param_shape))
+            key = f"{path[2]}.{path[3]}"
+
+            if key in scale_shape_map:
+                param_shape = scale_shape_map[key]
+            else:
+                raise ValueError(
+                    f"Scale shape for {key} not found in scale_shape_map.")
         param.value = get_random_sharded_array(
             rng, mesh, param, param_shape, param_dtype,
             ".".join([str(x) for x in path]))
-        prev_param_shape = param_shape
 
     # Handles the DeepSeek case, where this needs to be called to make the cache weights
     # concrete
tpu_inference/models/jax/utils/weight_utils.py CHANGED

@@ -1,3 +1,16 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """Utilities for downloading model weights from HuggingFace."""
 
 import functools
@@ -281,7 +294,8 @@ def _load_and_shard_weight(vllm_config,
                            hf_key: str,
                            hf_weight: jax.Array,
                            keep_original_dtype_keys_regex: list[str]
-                           | None = None):
+                           | None = None,
+                           pp_missing_layers: list[str] | None = None):
     name_map = metadata_map.name_map
     reshape_keys = metadata_map.reshape_map
     bias_reshape_keys = metadata_map.bias_reshape_map
@@ -337,6 +351,10 @@ def _load_and_shard_weight(vllm_config,
         return
     model_key = name_map.get(hf_key, hf_key)
 
+    if pp_missing_layers and _is_pp_missing_layer(hf_key, pp_missing_layers):
+        logger.warning(
+            f"Skip loading {hf_key} as it doesn't belong to this PP stage.")
+        return
     model_weight, model_sharding = get_param_and_sharding(
         params, shardings, model_key)
 
@@ -400,6 +418,14 @@ def _load_and_shard_weight(vllm_config,
         model_weight.value = shard(hf_weight, spec)
 
 
+def _is_pp_missing_layer(hf_key: str, pp_missing_layers: list[str]) -> bool:
+    has_digit = any(char.isdigit() for char in hf_key)
+    # Append a "." after digit-bearing names so that "layers.1" does not
+    # also match "layers.10".
+    suffix = "." if has_digit else ""
+    return any(f'{pp_missing_layer}{suffix}' in hf_key
+               for pp_missing_layer in pp_missing_layers)
+
+
 def _load_hf_weights_on_thread(
     vllm_config: VllmConfig,
     params: nnx.State,
@@ -408,6 +434,7 @@ def _load_hf_weights_on_thread(
     weights_file: str,
     filter_regex: Optional[str] = None,
     keep_original_dtype_keys_regex: Optional[list[str]] = None,
+    pp_missing_layers: list[str] | None = None,
 ):
     """Loads weights from a single weights file."""
     try:
@@ -426,6 +453,7 @@ def _load_hf_weights_on_thread(
             hf_key,
             hf_weight,
             keep_original_dtype_keys_regex,
+            pp_missing_layers,
         )
 
 
@@ -437,6 +465,7 @@ def load_hf_weights(
     filter_regex: Optional[str] = None,
     is_draft_model: bool = False,
     keep_original_dtype_keys_regex: Optional[list[str]] = None,
+    pp_missing_layers: list[str] | None = None,
 ):
     """Load weights into a JAX model from either an iterator or files."""
     params = nnx.state(model)
@@ -467,6 +496,7 @@ def load_hf_weights(
             hf_key,
             hf_weight_jax,
             keep_original_dtype_keys_regex,
+            pp_missing_layers=pp_missing_layers,
         )
     else:
         # File-based path (multi-threaded)
@@ -494,6 +524,7 @@ def load_hf_weights(
                 filter_regex=filter_regex,
                 keep_original_dtype_keys_regex=
                 keep_original_dtype_keys_regex,
+                pp_missing_layers=pp_missing_layers,
             ) for weights_file in weights_files
         ]
         for future in futures:
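The "." suffix appended in _is_pp_missing_layer is what keeps the prefix test precise: a bare substring check would treat "layers.1" as present in "layers.10", and a pipeline stage would silently skip weights it actually owns. A standalone copy of the helper showing the difference (the key names are hypothetical):

    def _is_pp_missing_layer(hf_key: str, pp_missing_layers: list[str]) -> bool:
        has_digit = any(char.isdigit() for char in hf_key)
        suffix = "." if has_digit else ""
        return any(f'{pp_missing_layer}{suffix}' in hf_key
                   for pp_missing_layer in pp_missing_layers)

    missing = ["model.layers.1"]
    # A naive substring test would also hit layer 10:
    assert "model.layers.1" in "model.layers.10.self_attn.q_proj.weight"
    # The suffixed check only skips layer 1 itself:
    assert _is_pp_missing_layer("model.layers.1.self_attn.q_proj.weight", missing)
    assert not _is_pp_missing_layer("model.layers.10.self_attn.q_proj.weight", missing)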
tpu_inference/models/vllm/__init__.py ADDED

@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tpu_inference/models/vllm/vllm_model_wrapper.py CHANGED

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import copy
 import functools
 from collections.abc import Sequence
@@ -23,8 +37,10 @@ from vllm.model_executor.models import supports_lora, supports_multimodal
 from vllm.sequence import IntermediateTensors
 
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.vllm.process_weights.cleanup_sharding import \
+    shard_model_to_tpu
 from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
-from tpu_inference.layers.vllm.sharding import shard_model_to_tpu
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.jax_intermediate_tensor import \
     JaxIntermediateTensors
@@ -197,7 +213,7 @@ class VllmModelWrapper:
             kwargs={
                 "input_ids": torch_view(input_ids),
                 "positions": torch_view(input_positions),
-                "intermediate_tensors": None,
+                "intermediate_tensors": intermediate_tensors,
                 "inputs_embeds": None,
             },
             tie_weights=False,
@@ -220,8 +236,10 @@ class VllmModelWrapper:
 
         @functools.partial(
             jax.jit,
-            out_shardings=(NamedSharding(self.mesh,
-                                         PartitionSpec("data", "model"))),
+            out_shardings=(NamedSharding(
+                self.mesh,
+                PartitionSpec(ShardingAxisName.MLP_DATA,
+                              ShardingAxisName.MLP_TENSOR))),
         )
         def compute_logits_func(
             params_and_buffers: Any,
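Swapping the literal ("data", "model") axis names for ShardingAxisName constants means the logits layout tracks whatever the MLP mesh axes are actually called. For reference, this is how out_shardings constrains a jitted function's result in plain JAX (the mesh shape and axis names below are placeholders, not tpu_inference's real mesh):

    import functools

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    # Toy 1x1 mesh; a real deployment reshapes the TPU devices into (data, model).
    mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), ("data", "model"))

    @functools.partial(
        jax.jit,
        out_shardings=NamedSharding(mesh, PartitionSpec("data", "model")))
    def compute_logits(x, w):
        # Rows sharded over "data", vocab columns over "model".
        return x @ w

    logits = compute_logits(jnp.ones((8, 16)), jnp.ones((16, 32)))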
tpu_inference/models/vllm/vllm_model_wrapper_context.py CHANGED

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import Dict, List, Optional
tpu_inference/platforms/__init__.py CHANGED

@@ -1,2 +1,16 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 # ruff: noqa
 from tpu_inference.platforms.tpu_platform import TpuPlatform
tpu_inference/platforms/tpu_platform.py CHANGED

@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
 
 import jax.numpy as jnp
 import torch
@@ -15,6 +15,7 @@ from tpu_inference.logger import init_logger
 
 if TYPE_CHECKING:
     from vllm.attention.backends.registry import AttentionBackendEnum
+    from vllm.attention.selector import AttentionSelectorConfig
     from vllm.config import BlockSize, ModelConfig, VllmConfig
     from vllm.pooling_params import PoolingParams
     from vllm.sampling_params import SamplingParams, SamplingType
@@ -51,11 +52,10 @@ class TpuPlatform(Platform):
 
     @classmethod
     def get_attn_backend_cls(cls, selected_backend: "AttentionBackendEnum",
-                             head_size: int, dtype: jnp.dtype,
-                             kv_cache_dtype: Optional[str], block_size: int,
-                             use_mla: bool, has_sink: bool, use_sparse: bool,
-                             use_mm_prefix: bool, attn_type: Any) -> str:
+                             attn_selector_config: "AttentionSelectorConfig",
+                             **kwargs) -> str:
         from vllm.attention.backends.registry import AttentionBackendEnum
+
         if selected_backend != AttentionBackendEnum.PALLAS:
             logger.info("Cannot use %s backend on TPU.", selected_backend)
 
@@ -145,17 +145,20 @@ class TpuPlatform(Platform):
         compilation_config.backend = "openxla"
 
         # TODO(cuiq): remove this dependency.
-        from vllm.v1.attention.backends.pallas import PallasAttentionBackend
-        cache_config.block_size = PallasAttentionBackend.get_page_size(
-            vllm_config)  # type: ignore[assignment]
-        min_page_size = PallasAttentionBackend.get_min_page_size(vllm_config)
-        if min_page_size > cache_config.block_size:
-            logger.warning(
-                "Increase the page size from %s to %s to avoid SMEM OOM",
-                cache_config.block_size,
-                min_page_size,
-            )
-            cache_config.block_size = min_page_size  # type: ignore[assignment]
+        if vllm_config.model_config:
+            from vllm.v1.attention.backends.pallas import \
+                PallasAttentionBackend
+            cache_config.block_size = PallasAttentionBackend.get_page_size(
+                vllm_config)  # type: ignore[assignment]
+            min_page_size = PallasAttentionBackend.get_min_page_size(
+                vllm_config)
+            if min_page_size > cache_config.block_size:
+                logger.warning(
+                    "Increase the page size from %s to %s to avoid SMEM OOM",
+                    cache_config.block_size,
+                    min_page_size,
+                )
+                cache_config.block_size = min_page_size  # type: ignore[assignment]
 
         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config
@@ -165,12 +168,12 @@ class TpuPlatform(Platform):
         multihost_backend = envs.TPU_MULTIHOST_BACKEND
         if not multihost_backend:  # Single host
             if parallel_config.pipeline_parallel_size == 1:
-                logger.info("Force using UniProcExecutor for JAX on \
-                    single host without pipeline parallelism.")
+                logger.info("Force using UniProcExecutor for JAX on "
+                            "single host without pipeline parallelism.")
                 parallel_config.distributed_executor_backend = "uni"
             else:
-                logger.info("Force using MultiprocExecutor for JAX on \
-                    single host with pipeline parallelism.")
+                logger.info("Force using MultiprocExecutor for JAX on "
+                            "single host with pipeline parallelism.")
                 parallel_config.distributed_executor_backend = "mp"
         elif multihost_backend == "ray":
             from tpu_inference.executors.ray_distributed_executor import \
@@ -186,20 +189,15 @@ class TpuPlatform(Platform):
 
         if scheduler_config.is_multimodal_model and not \
                 scheduler_config.disable_chunked_mm_input:
-            logger.warning("TPU does not support running Multimodal models" \
-                           " without setting `--disable_chunked_mm_input`. " \
-                           "Forcing --disable_chunked_mm_input.")
+            logger.warning("TPU does not support running Multimodal models"
+                           " without setting `--disable_chunked_mm_input`. "
+                           "Forcing --disable_chunked_mm_input.")
             scheduler_config.disable_chunked_mm_input = True
 
         kv_transfer_config = vllm_config.kv_transfer_config
         if kv_transfer_config is not None:
             assert kv_transfer_config.kv_connector == "TPUConnector"
-        # Late initialization to avoid circular import
-        from tpu_inference.models.jax.utils.quantization.quantization_utils import \
-            update_vllm_config_for_qwix_quantization
-
-        update_vllm_config_for_qwix_quantization(vllm_config)
-
+        # Late initialization to avoid circular import.
         from tpu_inference.core.sched.dp_scheduler import \
             update_vllm_config_for_dp_scheduler
         update_vllm_config_for_dp_scheduler(vllm_config)
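The logger fixes in the executor-selection hunk above repair a classic Python pitfall rather than just style: a backslash line-continuation inside a string literal splices the next source line, leading indentation included, into the message, while adjacent string literals concatenate with no stray whitespace. A two-string demonstration:

    import re

    broken = "Force using UniProcExecutor for JAX on \
        single host without pipeline parallelism."
    fixed = ("Force using UniProcExecutor for JAX on "
             "single host without pipeline parallelism.")

    assert re.search(r"on\s{2,}single", broken)  # indentation leaked into the message
    assert "on single" in fixed                  # clean implicit concatenation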
tpu_inference/runner/__init__.py ADDED

@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tpu_inference/runner/compilation_manager.py CHANGED

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import time
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
 
@@ -32,6 +46,8 @@ class CompilationManager:
 
     def __init__(self, runner: "TPUModelRunner"):
         self.runner = runner
+        self._sampling_precompiled = False
+        self._gather_logprobs_precompiled = False
        if not vllm_envs.VLLM_DISABLE_COMPILE_CACHE:
            logger.info("Enabling JAX compile cache.")
            jax.config.update("jax_compilation_cache_dir",
@@ -86,9 +102,13 @@
             return
         self._precompile_select_from_array()
         self._precompile_compute_logits()
+        # Skip sampling if already precompiled before KV cache allocation
+        if not self._sampling_precompiled:
+            self._precompile_sampling()
         self._precompile_disagg_utils()
-        self._precompile_sampling()
-        self._precompile_gather_logprobs()
+        # Skip gather_logprobs if already precompiled before KV cache allocation
+        if not self._gather_logprobs_precompiled:
+            self._precompile_gather_logprobs()
         self._precompile_structured_decoding()
         if self.runner.speculative_config:
             self._precompile_speculative_decoding()
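The new _sampling_precompiled / _gather_logprobs_precompiled flags make this precompile pass idempotent when those kernels were already warmed up before KV-cache allocation. The underlying pattern is JAX ahead-of-time compilation: lowering a jitted function once per padded shape fills the compile cache so serving never pays a JIT stall. A minimal sketch of that guard-plus-warm-up pattern (the stub function and paddings are placeholders, not the runner's real ones):

    import jax
    import jax.numpy as jnp

    @jax.jit
    def sample_stub(logits):
        # Stand-in for the real sample(): greedy pick over the vocab axis.
        return jnp.argmax(logits, axis=-1)

    NUM_REQS_PADDINGS = (8, 16, 32)  # hypothetical padded batch sizes
    _sampling_precompiled = False

    def precompile_sampling(vocab_size: int = 256) -> None:
        global _sampling_precompiled
        if _sampling_precompiled:  # same role as the flag in __init__ above
            return
        for num_reqs in NUM_REQS_PADDINGS:
            dummy = jnp.zeros((num_reqs, vocab_size), jnp.bfloat16)
            # AOT lower + compile now; later calls with these shapes hit the cache.
            sample_stub.lower(dummy).compile()
        _sampling_precompiled = True

    precompile_sampling()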
@@ -107,7 +127,7 @@
 
         self._run_compilation(
             "input_embeddings_merger",
-            self.runner.get_input_embeddings_fn,
+            self.runner.embed_input_ids_fn,
             self.runner.state,
             dummy_input_ids,
             dummy_multimodal_embeddings,
@@ -116,7 +136,7 @@
 
         self._run_compilation(
             "input_embeddings_merger_text_only",
-            self.runner.get_input_embeddings_fn,
+            self.runner.embed_input_ids_fn,
             self.runner.state,
             dummy_input_ids,
             None,
@@ -475,35 +495,39 @@
             logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16,
                                                logits_sharding)
             for do_sampling in (True, False):
-                if do_sampling:
-                    temperature = np.full((num_reqs, ), 0.7, dtype=np.float32)
-                    top_k = np.full((num_reqs, ), 20, dtype=np.int32)
-                    top_p = np.full((num_reqs, ), 0.8, dtype=np.float32)
-                    (temperature, top_k,
-                     top_p) = device_array(self.runner.mesh,
-                                           (temperature, top_k, top_p),
-                                           sharding=sampling_metadata_sharding)
-                else:
-                    temperature = None
-                    top_k = None
-                    top_p = None
-
-                sampling_metadata = TPUSupportedSamplingMetadata(
-                    temperature=temperature,
-                    top_k=top_k,
-                    top_p=top_p,
-                    do_sampling=do_sampling,
-                )
-                self._run_compilation(
-                    f"worker{self.runner.rank} sample",
-                    sample,
-                    self.runner.rng_params_for_sampling,
-                    self.runner.mesh,
-                    logits,
-                    sampling_metadata,
-                    num_reqs=num_reqs,
-                    do_sampling=do_sampling,
-                )
+                for logprobs in (True, False):
+                    if do_sampling:
+                        temperature = np.full((num_reqs, ),
+                                              0.7,
+                                              dtype=np.float32)
+                        top_k = np.full((num_reqs, ), 20, dtype=np.int32)
+                        top_p = np.full((num_reqs, ), 0.8, dtype=np.float32)
+                        (temperature, top_k, top_p) = device_array(
+                            self.runner.mesh, (temperature, top_k, top_p),
+                            sharding=sampling_metadata_sharding)
+                    else:
+                        temperature = None
+                        top_k = None
+                        top_p = None
+
+                    sampling_metadata = TPUSupportedSamplingMetadata(
+                        temperature=temperature,
+                        top_k=top_k,
+                        top_p=top_p,
+                        do_sampling=do_sampling,
+                        logprobs=logprobs)
+                    self._run_compilation(
+                        f"worker{self.runner.rank} sample",
+                        sample,
+                        self.runner.rng_params_for_sampling,
+                        self.runner.mesh,
+                        logits,
+                        sampling_metadata,
+                        num_reqs=num_reqs,
+                        do_sampling=do_sampling,
+                    )
+
+        self._sampling_precompiled = True
 
     def _precompile_disagg_utils(self) -> None:
         if not is_disagg_enabled():
@@ -533,8 +557,16 @@
         logger.info("Compiling gather_logprobs with different input shapes.")
         hsize = self.runner.model_config.get_vocab_size()
         for num_reqs in self.runner.num_reqs_paddings:
-            logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16)
-            token_ids = self._create_dummy_tensor((num_reqs, ), jnp.int32)
+            logits_sharding = NamedSharding(
+                self.runner.mesh,
+                PartitionSpec(ShardingAxisName.MLP_DATA,
+                              ShardingAxisName.MLP_TENSOR))
+            token_ids_sharding = NamedSharding(
+                self.runner.mesh, PartitionSpec(ShardingAxisName.MLP_DATA, ))
+            logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16,
+                                               logits_sharding)
+            token_ids = self._create_dummy_tensor((num_reqs, ), jnp.int32,
+                                                  token_ids_sharding)
             self._run_compilation(
                 f"worker{self.runner.rank} gather_logprobs",
                 self.runner._compute_and_gather_logprobs,
@@ -544,6 +576,8 @@
                 num_reqs=num_reqs,
             )
 
+        self._gather_logprobs_precompiled = True
+
     def _precompile_speculative_decoding(self) -> None:
         logger.info(
             "Compiling speculative_decoding with different input shapes.")