tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (257)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +317 -34
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +406 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +320 -0
  64. tests/layers/vllm/test_unquantized.py +662 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +26 -6
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +25 -4
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +25 -12
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  154. tpu_inference/layers/common/quant_methods.py +15 -0
  155. tpu_inference/layers/common/quantization.py +282 -0
  156. tpu_inference/layers/common/sharding.py +32 -9
  157. tpu_inference/layers/common/utils.py +94 -0
  158. tpu_inference/layers/jax/__init__.py +13 -0
  159. tpu_inference/layers/jax/attention/__init__.py +13 -0
  160. tpu_inference/layers/jax/attention/attention.py +19 -6
  161. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  162. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  163. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  164. tpu_inference/layers/jax/base.py +14 -0
  165. tpu_inference/layers/jax/constants.py +13 -0
  166. tpu_inference/layers/jax/layers.py +14 -0
  167. tpu_inference/layers/jax/misc.py +14 -0
  168. tpu_inference/layers/jax/moe/__init__.py +13 -0
  169. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  170. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  171. tpu_inference/layers/jax/moe/moe.py +43 -3
  172. tpu_inference/layers/jax/pp_utils.py +53 -0
  173. tpu_inference/layers/jax/rope.py +14 -0
  174. tpu_inference/layers/jax/rope_interface.py +14 -0
  175. tpu_inference/layers/jax/sample/__init__.py +13 -0
  176. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  177. tpu_inference/layers/jax/sample/sampling.py +15 -1
  178. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  179. tpu_inference/layers/jax/transformer_block.py +14 -0
  180. tpu_inference/layers/vllm/__init__.py +13 -0
  181. tpu_inference/layers/vllm/attention.py +4 -4
  182. tpu_inference/layers/vllm/fused_moe.py +101 -494
  183. tpu_inference/layers/vllm/linear.py +64 -0
  184. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  185. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  186. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  187. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  188. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  189. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  191. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
  192. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
  193. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  194. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  195. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  196. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
  197. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  198. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
  199. tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
  200. tpu_inference/lora/__init__.py +13 -0
  201. tpu_inference/lora/torch_lora_ops.py +8 -13
  202. tpu_inference/models/__init__.py +13 -0
  203. tpu_inference/models/common/__init__.py +13 -0
  204. tpu_inference/models/common/model_loader.py +112 -35
  205. tpu_inference/models/jax/__init__.py +13 -0
  206. tpu_inference/models/jax/deepseek_v3.py +267 -157
  207. tpu_inference/models/jax/gpt_oss.py +26 -10
  208. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  209. tpu_inference/models/jax/llama3.py +99 -36
  210. tpu_inference/models/jax/llama4.py +14 -0
  211. tpu_inference/models/jax/llama_eagle3.py +18 -5
  212. tpu_inference/models/jax/llama_guard_4.py +15 -1
  213. tpu_inference/models/jax/qwen2.py +17 -2
  214. tpu_inference/models/jax/qwen2_5_vl.py +179 -51
  215. tpu_inference/models/jax/qwen3.py +17 -2
  216. tpu_inference/models/jax/utils/__init__.py +13 -0
  217. tpu_inference/models/jax/utils/file_utils.py +14 -0
  218. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  219. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  220. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
  221. tpu_inference/models/jax/utils/weight_utils.py +234 -155
  222. tpu_inference/models/vllm/__init__.py +13 -0
  223. tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
  224. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  225. tpu_inference/platforms/__init__.py +14 -0
  226. tpu_inference/platforms/tpu_platform.py +51 -72
  227. tpu_inference/runner/__init__.py +13 -0
  228. tpu_inference/runner/compilation_manager.py +180 -80
  229. tpu_inference/runner/kv_cache.py +54 -20
  230. tpu_inference/runner/kv_cache_manager.py +55 -33
  231. tpu_inference/runner/lora_utils.py +16 -1
  232. tpu_inference/runner/multimodal_manager.py +16 -2
  233. tpu_inference/runner/persistent_batch_manager.py +54 -2
  234. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  235. tpu_inference/runner/structured_decoding_manager.py +16 -3
  236. tpu_inference/runner/tpu_runner.py +124 -61
  237. tpu_inference/runner/utils.py +2 -2
  238. tpu_inference/spec_decode/__init__.py +13 -0
  239. tpu_inference/spec_decode/jax/__init__.py +13 -0
  240. tpu_inference/spec_decode/jax/eagle3.py +84 -22
  241. tpu_inference/tpu_info.py +14 -0
  242. tpu_inference/utils.py +72 -44
  243. tpu_inference/worker/__init__.py +13 -0
  244. tpu_inference/worker/tpu_worker.py +66 -52
  245. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
  246. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  247. tpu_inference/layers/vllm/linear_common.py +0 -186
  248. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  249. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  250. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  251. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  252. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  253. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  254. tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
  255. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  256. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  257. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tpu_inference/platforms/tpu_platform.py
@@ -1,38 +1,35 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
 
 import jax.numpy as jnp
+import torch
 import vllm.envs as vllm_envs
-from torchax.ops.mappings import j2t_dtype
 from tpu_info import device
 from vllm.inputs import ProcessorInputs, PromptType
 from vllm.platforms.interface import Platform, PlatformEnum
-from vllm.sampling_params import SamplingParams, SamplingType
 
 from tpu_inference import envs
 from tpu_inference.layers.common.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
 
 if TYPE_CHECKING:
-    from vllm.attention.backends.registry import _Backend
+    from vllm.attention.backends.registry import AttentionBackendEnum
+    from vllm.attention.selector import AttentionSelectorConfig
     from vllm.config import BlockSize, ModelConfig, VllmConfig
     from vllm.pooling_params import PoolingParams
+    from vllm.sampling_params import SamplingParams, SamplingType
 else:
     BlockSize = None
     ModelConfig = None
     VllmConfig = None
     PoolingParams = None
-    _Backend = None
+    AttentionBackendEnum = None
+    SamplingParams = None
+    SamplingType = None
 
 logger = init_logger(__name__)
 
-_DTYPE: dict[str, jnp.dtype] = {
-    "bfloat16": jnp.bfloat16,
-    "float": jnp.float32,
-    "float32": jnp.float32,
-}
-
 
 class TpuPlatform(Platform):
     _enum = PlatformEnum.TPU
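
The import churn above extends the module's existing TYPE_CHECKING idiom: names needed only in annotations (now also SamplingParams and SamplingType) are imported for static type checkers and stubbed to None at runtime, so importing the platform module stays cheap. A minimal, self-contained sketch of that idiom (the heavy_vllm_module name is illustrative, not from this diff):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers; never executed at runtime.
    from heavy_vllm_module import SamplingParams
else:
    # Runtime stub keeps the name defined without paying the import cost.
    SamplingParams = None

def validate(params: "SamplingParams") -> None:
    # The quoted annotation means the real class is never needed at
    # definition time; a method can import it lazily when called.
    ...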
@@ -49,25 +46,21 @@ class TpuPlatform(Platform):
 
     additional_env_vars: list[str] = [
         "PHASED_PROFILING_DIR", "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS",
-        "TPU_MULTIHOST_BACKEND", "VLLM_MLA_DISABLE", "TPU_BACKEND_TYPE"
+        "TPU_MULTIHOST_BACKEND", "VLLM_MLA_DISABLE", "TPU_BACKEND_TYPE",
+        "NEW_MODEL_DESIGN"
     ]
 
     @classmethod
-    def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
-                             dtype: jnp.dtype, kv_cache_dtype: Optional[str],
-                             block_size: int, use_v1: bool, use_mla: bool,
-                             has_sink: bool, use_sparse: bool,
-                             attn_type: Any) -> str:
-        from vllm.attention.backends.registry import _Backend
-        if selected_backend != _Backend.PALLAS:
+    def get_attn_backend_cls(cls, selected_backend: "AttentionBackendEnum",
+                             attn_selector_config: "AttentionSelectorConfig",
+                             **kwargs) -> str:
+        from vllm.attention.backends.registry import AttentionBackendEnum
+
+        if selected_backend != AttentionBackendEnum.PALLAS:
             logger.info("Cannot use %s backend on TPU.", selected_backend)
 
-        if use_v1:
-            logger.info("Using Pallas V1 backend.")
-            return "tpu_inference.layers.vllm.attention.PallasAttentionBackend"
-        else:
-            logger.info("Using Pallas backend.")
-            return "vllm.attention.backends.pallas.PallasAttentionBackend"
+        logger.info("Using Pallas V1 backend.")
+        return "tpu_inference.layers.vllm.attention.PallasAttentionBackend"
 
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
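
The rewritten selector drops the use_v1 branch entirely: tpu-inference now always answers with its own Pallas backend path. A standalone sketch of the new dispatch, with a stand-in enum replacing vllm.attention.backends.registry.AttentionBackendEnum:

from enum import Enum, auto

class AttentionBackendEnum(Enum):  # stand-in for the vLLM registry enum
    PALLAS = auto()
    FLASH_ATTN = auto()

def get_attn_backend_cls(selected_backend: AttentionBackendEnum, **kwargs) -> str:
    if selected_backend != AttentionBackendEnum.PALLAS:
        print(f"Cannot use {selected_backend} backend on TPU.")
    # The legacy non-V1 branch is gone; Pallas V1 is the only answer now.
    return "tpu_inference.layers.vllm.attention.PallasAttentionBackend"

assert get_attn_backend_cls(AttentionBackendEnum.PALLAS).endswith(
    "PallasAttentionBackend")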
@@ -82,6 +75,14 @@ class TpuPlatform(Platform):
             logger.warning(f"Error getting device name: {e}")
             return 'TPU'
 
+    @classmethod
+    def fp8_dtype(cls) -> torch.dtype:
+        if cls.get_device_name().lower() == "tpu v6e":
+            logger.info(
+                "Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.")
+            return torch.float8_e5m2
+        return torch.float8_e4m3fn
+
     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:
         raise NotImplementedError
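
The new fp8_dtype hook selects the FP8 KV-cache dtype by TPU generation: float8_e5m2 on v6e, float8_e4m3fn everywhere else. The same decision as a standalone function (device-name strings modelled on the diff's "tpu v6e" check; requires torch >= 2.1 for the float8 dtypes):

import torch

def pick_fp8_kv_cache_dtype(device_name: str) -> torch.dtype:
    # Mirrors TpuPlatform.fp8_dtype: v6e prefers e5m2, all other
    # generations get the default e4m3fn.
    if device_name.lower() == "tpu v6e":
        return torch.float8_e5m2
    return torch.float8_e4m3fn

assert pick_fp8_kv_cache_dtype("TPU v6e") is torch.float8_e5m2
assert pick_fp8_kv_cache_dtype("TPU v5e") is torch.float8_e4m3fn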
@@ -132,6 +133,7 @@ class TpuPlatform(Platform):
         # For v0, the default block size is 16.
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = cast(BlockSize, 16)
+
         compilation_config = vllm_config.compilation_config
 
         # TPU only supports DYNAMO_TRACE_ONCE compilation level
@@ -142,40 +144,21 @@ class TpuPlatform(Platform):
         if compilation_config.backend == "":
             compilation_config.backend = "openxla"
 
-        # If we use vLLM's model implementation in PyTorch, we should set it with torch version of the dtype.
-        impl = envs.MODEL_IMPL_TYPE
-
-        # NOTE(xiang): convert dtype to jnp.dtype
-        # NOTE(wenlong): skip this logic for mm model preprocessing
-        # For mm model preprocessors, it may need the output dtype to be torch.
-        # In order to avoid a PR to vLLM, we postpone the dtype checking during tpu_worker initialization
-        if not vllm_config.scheduler_config.is_multimodal_model or impl == "vllm":
-            if not isinstance(vllm_config.model_config.dtype, str):
-                logger.warning(
-                    "The model dtype is not properly set for JAX backend. "
-                    "Overwriting it to jnp.bfloat16")
-                vllm_config.model_config.dtype = jnp.bfloat16
-            else:
-                vllm_config.model_config.dtype = _DTYPE.get(
-                    vllm_config.model_config.dtype, jnp.bfloat16)
-
-            if impl == "vllm":
-                vllm_config.model_config.dtype = j2t_dtype(
-                    vllm_config.model_config.dtype.dtype)
-
         # TODO(cuiq): remove this dependency.
-        from vllm.v1.attention.backends.pallas import PallasAttentionBackend
-        cache_config.block_size = PallasAttentionBackend.get_page_size(
-            vllm_config)  # type: ignore[assignment]
-        min_page_size = PallasAttentionBackend.get_min_page_size(vllm_config)
-        if min_page_size > cache_config.block_size:
-            logger.warning(
-                "Increase the page size from %s to %s to make sure there's"
-                "no SMEM OOM",
-                cache_config.block_size,
-                min_page_size,
-            )
-            cache_config.block_size = min_page_size  # type: ignore[assignment]
+        if vllm_config.model_config:
+            from vllm.v1.attention.backends.pallas import \
+                PallasAttentionBackend
+            cache_config.block_size = PallasAttentionBackend.get_page_size(
+                vllm_config)  # type: ignore[assignment]
+            min_page_size = PallasAttentionBackend.get_min_page_size(
+                vllm_config)
+            if min_page_size > cache_config.block_size:
+                logger.warning(
+                    "Increase the page size from %s to %s to avoid SMEM OOM",
+                    cache_config.block_size,
+                    min_page_size,
+                )
+                cache_config.block_size = min_page_size  # type: ignore[assignment]
 
         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config
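
Functionally the block-size fix-up is unchanged; it is now simply skipped when vllm_config.model_config is absent. The guard itself, as a standalone sketch (min_page_size would come from PallasAttentionBackend.get_min_page_size in the real code):

import logging

logger = logging.getLogger(__name__)

def ensure_min_page_size(block_size: int, min_page_size: int) -> int:
    # If the kernel's minimum page size exceeds the configured block
    # size, raise the block size to avoid SMEM OOM.
    if min_page_size > block_size:
        logger.warning("Increase the page size from %s to %s to avoid "
                       "SMEM OOM", block_size, min_page_size)
        return min_page_size
    return block_size

assert ensure_min_page_size(16, 32) == 32    # bumped
assert ensure_min_page_size(128, 32) == 128  # already large enough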
@@ -185,12 +168,12 @@ class TpuPlatform(Platform):
         multihost_backend = envs.TPU_MULTIHOST_BACKEND
         if not multihost_backend:  # Single host
             if parallel_config.pipeline_parallel_size == 1:
-                logger.info("Force using UniProcExecutor for JAX on \
-                    single host without pipeline parallelism.")
+                logger.info("Force using UniProcExecutor for JAX on "
+                            "single host without pipeline parallelism.")
                 parallel_config.distributed_executor_backend = "uni"
             else:
-                logger.info("Force using MultiprocExecutor for JAX on \
-                    single host with pipeline parallelism.")
+                logger.info("Force using MultiprocExecutor for JAX on "
+                            "single host with pipeline parallelism.")
                 parallel_config.distributed_executor_backend = "mp"
         elif multihost_backend == "ray":
             from tpu_inference.executors.ray_distributed_executor import \
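
The executor selection reads as a small decision table: no multihost backend means single host, where pipeline parallelism picks between the uniprocess ("uni") and multiprocess ("mp") executors, while TPU_MULTIHOST_BACKEND=ray routes to tpu-inference's RayDistributedExecutor (the import is truncated at the hunk boundary above). A compact sketch of that table; the function name is illustrative:

from typing import Optional

def pick_executor_backend(multihost_backend: Optional[str],
                          pipeline_parallel_size: int) -> str:
    if not multihost_backend:  # single host
        return "uni" if pipeline_parallel_size == 1 else "mp"
    if multihost_backend == "ray":
        return "ray"  # RayDistributedExecutor in the real code
    raise ValueError(f"Unsupported multihost backend: {multihost_backend}")

assert pick_executor_backend(None, 1) == "uni"
assert pick_executor_backend(None, 2) == "mp"
assert pick_executor_backend("ray", 8) == "ray"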
@@ -206,20 +189,15 @@ class TpuPlatform(Platform):
 
         if scheduler_config.is_multimodal_model and not \
             scheduler_config.disable_chunked_mm_input:
-            logger.warning("TPU does not support running Multimodal models"\
-            " without setting `--disable_chunked_mm_input`. " \
-            "Forcing --disable_chunked_mm_input.")
+            logger.warning("TPU does not support running Multimodal models"
+                           " without setting `--disable_chunked_mm_input`. "
+                           "Forcing --disable_chunked_mm_input.")
             scheduler_config.disable_chunked_mm_input = True
 
         kv_transfer_config = vllm_config.kv_transfer_config
         if kv_transfer_config is not None:
             assert kv_transfer_config.kv_connector == "TPUConnector"
-        # Late initialization to avoid circular import
-        from tpu_inference.models.jax.utils.quantization.quantization_utils import \
-            update_vllm_config_for_qwix_quantization
-
-        update_vllm_config_for_qwix_quantization(vllm_config)
-
+        # Late initialization to avoid circular import.
         from tpu_inference.core.sched.dp_scheduler import \
            update_vllm_config_for_dp_scheduler
         update_vllm_config_for_dp_scheduler(vllm_config)
@@ -246,10 +224,11 @@
     def validate_request(
         cls,
         prompt: PromptType,
-        params: Union[SamplingParams, PoolingParams],
+        params: Union["SamplingParams", PoolingParams],
        processed_inputs: ProcessorInputs,
     ) -> None:
         """Raises if this request is unsupported on this platform"""
+        from vllm.sampling_params import SamplingParams, SamplingType
 
         if isinstance(params, SamplingParams):
             if params.sampling_type == SamplingType.RANDOM_SEED:
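
This pairs with the first hunk: SamplingParams and SamplingType now live under TYPE_CHECKING for annotations and are imported inside validate_request at call time, keeping the module import lightweight. A minimal sketch of the deferred-import pattern; the rejection branch is illustrative, since the hunk is truncated before the actual handling:

def validate_request(params) -> None:
    # Deferred import: the dependency is resolved only when a request
    # is actually validated, not when this module is imported.
    from vllm.sampling_params import SamplingParams, SamplingType

    if isinstance(params, SamplingParams):
        if params.sampling_type == SamplingType.RANDOM_SEED:
            # Illustrative only: the diff cuts off before showing how
            # this case is handled in the real method.
            raise ValueError("per-request seeding is not supported")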
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.