tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.2rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. Click here for more details.

Files changed (250) hide show
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +78 -1
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +1 -43
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +14 -9
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +38 -7
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +17 -0
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +28 -5
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +74 -35
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +89 -26
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -64
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +72 -37
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +46 -17
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +14 -0
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +44 -17
  232. tpu_inference/spec_decode/__init__.py +13 -0
  233. tpu_inference/spec_decode/jax/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  235. tpu_inference/tpu_info.py +14 -0
  236. tpu_inference/utils.py +42 -36
  237. tpu_inference/worker/__init__.py +13 -0
  238. tpu_inference/worker/tpu_worker.py +63 -50
  239. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/METADATA +7 -9
  240. tpu_inference-0.13.2rc3.dist-info/RECORD +261 -0
  241. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  242. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  245. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  246. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  247. tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
  248. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/WHEEL +0 -0
  249. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/licenses/LICENSE +0 -0
  250. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,311 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+
4
+ from unittest import mock
5
+
6
+ import jax
7
+ import jax.numpy as jnp
8
+ import numpy as np
9
+ import pytest
10
+ from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
11
+ ParallelConfig, SchedulerConfig, SpeculativeConfig,
12
+ VllmConfig)
13
+ from vllm.config.load import LoadConfig
14
+
15
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
16
+ from tpu_inference.runner import utils as runner_utils
17
+ from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
18
+
19
+ # Use a real model dir for config, but we will mock model loading/execution
20
+ model_dir = "meta-llama/Llama-3.1-8B-Instruct"
21
+ eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
22
+
23
+
24
def _create_proposer(
    method: str,
    num_speculative_tokens: int,
) -> Eagle3Proposer:
    """Build an ``Eagle3Proposer`` from real vLLM configs and a mock runner.

    Args:
        method: Speculative decoding method handed to ``SpeculativeConfig``.
        num_speculative_tokens: Draft-token count handed to
            ``SpeculativeConfig``.

    Returns:
        A proposer created from the assembled ``VllmConfig`` and a
        ``MagicMock`` runner that carries a real JAX mesh.
    """
    target_model_config = ModelConfig(
        model=model_dir,
        runner="generate",
        max_model_len=8192,
        seed=42,
    )

    draft_config = SpeculativeConfig(
        target_model_config=target_model_config,
        target_parallel_config=ParallelConfig(),
        model=eagle3_dir,
        method=method,
        num_speculative_tokens=num_speculative_tokens,
    )

    # Sub-configs are created one by one, in the order they are passed to
    # VllmConfig below.
    cache_config = CacheConfig(block_size=16)
    device_config = DeviceConfig(device="tpu")
    parallel_config = ParallelConfig(pipeline_parallel_size=1,
                                     tensor_parallel_size=1)
    load_config = LoadConfig()
    scheduler_config = SchedulerConfig(
        max_num_batched_tokens=8192,
        max_num_seqs=128,
        max_model_len=target_model_config.max_model_len,
        is_encoder_decoder=False,
    )
    full_config = VllmConfig(
        model_config=target_model_config,
        cache_config=cache_config,
        speculative_config=draft_config,
        device_config=device_config,
        parallel_config=parallel_config,
        load_config=load_config,
        scheduler_config=scheduler_config,
    )

    # The proposer takes a runner at construction time; a MagicMock stands in
    # for it, with only the mesh being a real object so sharding-related
    # logic has something concrete to work with.
    fake_runner = mock.MagicMock()
    all_devices = np.array(jax.devices())
    fake_runner.mesh = jax.sharding.Mesh(all_devices, axis_names=('model', ))
    fake_runner.max_num_tokens = 8192
    fake_runner.max_model_len = 8192
    fake_runner.kv_cache_config.kv_cache_groups = [mock.MagicMock()]
    fake_runner.input_batch = mock.MagicMock()

    return Eagle3Proposer(vllm_config=full_config, runner=fake_runner)
66
+
67
+
68
def test_prepare_inputs():
    """Check ``prepare_inputs`` against hand-computed post-rejection values.

    Mirrors the GPU test for prepare_inputs, adapted for JAX. For query
    lengths ``[a, b, c]`` and rejected counts ``[n1, n2, n3]``:
    - cu_target_query_lens: [0, a, a + b, a + b + c]
    - num_rejected_tokens: [n1, n2, n3]
    - num_tokens_per_req: [a - n1, b - n2, c - n3]
    - cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
    - token_indices: [0, ..., a - n1 - 1, a, ..., a + b - n2 - 1, ...]
    """
    num_reqs = 3
    max_num_seqs = 128
    max_num_blocks_per_req = 10  # Mock value
    total_tokens = 16
    hidden_size = 128

    # Per-request fixtures for the three active requests.
    active_query_lens = [4, 7, 5]
    active_rejected = [1, 3, 2]
    active_kept = [3, 4, 3]  # query length minus rejected tokens
    active_next_tokens = [1, 2, 3]

    proposer = _create_proposer("eagle3", 1)
    runner = proposer.runner

    # Runner attributes the proposer reads.
    runner.input_batch.num_reqs = num_reqs
    runner.num_tokens_paddings = runner_utils.get_token_paddings(
        min_token_size=16, max_token_size=1024, padding_gap=0)

    # Mocks required by the _prepare_draft_inputs helper.
    proposer.combine_hidden_states_fn = lambda state, h: h  # passthrough
    proposer.state = None
    # The block table mock hands back a 2D array.
    cpu_block_table = mock.MagicMock()
    cpu_block_table.get_cpu_tensor.return_value = jnp.zeros(
        (num_reqs, max_num_blocks_per_req), dtype=jnp.int32)
    runner.input_batch.block_table = [cpu_block_table]

    # --- Sequence data, padded out to max_num_seqs ---
    query_lens = np.zeros(max_num_seqs, dtype=np.int32)
    query_lens[:num_reqs] = active_query_lens
    qsl_cpu = np.concatenate(([0], np.cumsum(query_lens))).astype(np.int32)
    sl_cpu = query_lens.copy()

    # --- Inputs ---
    # input_ids must be large enough to be indexed by token_indices, which
    # can access up to total_tokens for padded requests.
    input_ids = jnp.arange(total_tokens + 1)
    aux_hidden_states = tuple(
        jnp.ones((total_tokens + 1, hidden_size)) for _ in range(3))

    rejected_cpu = np.zeros(max_num_seqs, dtype=np.int32)
    rejected_cpu[:num_reqs] = active_rejected
    num_rejected_tokens = jnp.array(rejected_cpu)

    # Only used in the _prepare_input_ids helper; it must be padded to
    # max_num_seqs (128) to match the mask in jnp.where.
    next_tokens_cpu = np.zeros(max_num_seqs, dtype=np.int32)
    next_tokens_cpu[:num_reqs] = active_next_tokens
    next_token_ids = jnp.array(next_tokens_cpu)

    attn_metadata = AttentionMetadata(
        seq_lens=jnp.array(sl_cpu),
        input_positions=jnp.arange(total_tokens),
        query_start_loc=jnp.array(qsl_cpu),
        block_tables=jnp.array([]),  # This will be replaced by the mock
        request_distribution=None,
    )
    attn_metadata.query_start_loc_cpu = qsl_cpu
    attn_metadata.seq_lens_cpu = sl_cpu

    # --- Expected results ---
    # The implementation sets padded query lengths to 1, and rejected tokens
    # are 0 for padded requests, so every padded slot contributes one token.
    kept_per_req = np.ones(max_num_seqs, dtype=np.int32)
    kept_per_req[:num_reqs] = active_kept
    expected_new_qsl = np.concatenate(
        ([0], np.cumsum(kept_per_req))).astype(np.int32)

    expected_new_seq_lens = np.zeros(max_num_seqs, dtype=np.int32)
    expected_new_seq_lens[:num_reqs] = active_kept

    expected_total_tokens = runner_utils.get_padded_token_len(
        runner.num_tokens_paddings, int(expected_new_qsl[-1]))
    expected_last_token_indices = jnp.array(expected_new_qsl[1:] - 1)

    # --- Execute ---
    result = proposer.prepare_inputs(attn_metadata, input_ids,
                                     aux_hidden_states, next_token_ids,
                                     num_rejected_tokens)
    (target_hidden_states, input_ids, last_token_indices,
     updated_metadata) = result

    # --- Assertions ---
    assert jnp.array_equal(updated_metadata.query_start_loc,
                           jnp.array(expected_new_qsl))
    assert jnp.array_equal(updated_metadata.seq_lens,
                           jnp.array(expected_new_seq_lens))
    assert jnp.array_equal(last_token_indices, expected_last_token_indices)

    # NOTE: We don't check the content of target_token_ids for padded
    # requests as it's complicated to construct the expected tensor. The
    # shape check and the qsl/seq_len checks are sufficient to validate the
    # logic.
    assert input_ids.shape == (expected_total_tokens, )
    # The concatenated hidden state shape should be (..., hidden_size * 3)
    assert target_hidden_states.shape == (expected_total_tokens,
                                          hidden_size * 3)
171
+
172
+
173
@pytest.mark.parametrize("method", ["eagle3"])
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
def test_propose(method, num_speculative_tokens):
    """Drive ``propose`` with stub model functions and check the drafts.

    The stubs carry a token id in column 0 of the hidden state. The
    expected draft tokens for request ``i`` are therefore
    ``base_token_ids[i]``, ``base_token_ids[i] + 1``, and so on.
    """
    hidden_size = 128
    vocab_size = 100
    batch_size = 2
    seq_len_1, seq_len_2 = 5, 3
    total_tokens = seq_len_1 + seq_len_2
    base_token_ids = [42, 60]

    proposer = _create_proposer(method, num_speculative_tokens)

    def fake_model_fn(state, kv_caches, input_ids, target_hidden_states,
                      attn_metadata):
        """Stub model_fn.

        Returns: (kv_caches, hidden_states_for_logits, residual_tuple)

        - First call (num_tokens == total_tokens): write base_token_ids
          into column 0 at the last-token rows of both outputs.
        - Loop calls (any other num_tokens): write ``input_ids + 1`` into
          column 0 of every row of both outputs.
        """
        num_tokens = input_ids.shape[0]
        blank = jnp.zeros((num_tokens, hidden_size))

        if num_tokens == total_tokens:
            # First call in propose, which selects from last_token_indices.
            rows = attn_metadata.query_start_loc[1:] - 1
            for_logits = blank.at[rows, 0].set(jnp.array(base_token_ids))
            residual = blank.at[rows, 0].set(jnp.array(base_token_ids))
        else:
            # Later calls: input_ids is the previous draft token, and the
            # stub's rule is next token = previous token + 1.
            bumped = input_ids + 1
            for_logits = blank.at[:, 0].set(bumped)
            residual = blank.at[:, 0].set(bumped)

        # (kv_caches, output used for logits, tuple fed into the next step)
        return kv_caches, for_logits, (residual, )

    def fake_compute_logits_fn(state, hidden_states, lora_metadata):
        # Deterministic logits: one-hot of the id stored in column 0.
        encoded_ids = hidden_states[:, 0].astype(jnp.int32)
        return jax.nn.one_hot(encoded_ids, vocab_size)

    def fake_combine_hidden_states_fn(state, hidden_states):
        # Passthrough, as the stub doesn't need combination.
        return hidden_states

    proposer.model_fn = fake_model_fn
    proposer.compute_logits_fn = fake_compute_logits_fn
    proposer.combine_hidden_states_fn = fake_combine_hidden_states_fn
    proposer.state = None  # Mock state

    # --- Inputs ---
    kv_caches = [None]  # Mock kv_caches

    # Build the 2D table first, as this is what the (unused) mock expects.
    block_tables_2d = jnp.zeros((batch_size, 10), dtype=jnp.int32)

    positions = jnp.concatenate(
        [jnp.arange(seq_len_1),
         jnp.arange(seq_len_2)])
    attn_metadata = AttentionMetadata(
        seq_lens=jnp.array([seq_len_1, seq_len_2]),
        input_positions=positions,
        query_start_loc=jnp.array([0, seq_len_1, total_tokens]),
        # Pass the FLATTENED table to simulate output of prepare_inputs
        block_tables=block_tables_2d.reshape(-1),
        request_distribution=None,
    )

    # Stand-ins for what prepare_inputs would hand to propose.
    target_token_ids = jnp.zeros(total_tokens, dtype=jnp.int32)
    target_hidden_states = jnp.zeros((total_tokens, hidden_size))
    last_token_indices = attn_metadata.query_start_loc[1:] - 1

    # This mock isn't actually used by propose(), but it is set to the 2D
    # table for correctness, as that's what _prepare_draft_inputs (called by
    # prepare_inputs) would expect.
    device_block_table = mock.MagicMock()
    device_block_table.get_device_tensor.return_value = block_tables_2d
    proposer.runner.input_batch.num_reqs = batch_size
    proposer.runner.input_batch.block_table = [device_block_table]

    # --- Execute ---
    _, draft_token_ids = proposer.propose(
        kv_caches,
        target_token_ids,
        attn_metadata,
        last_token_indices,
        target_hidden_states,
    )

    # Normalise a 1D result to (batch, 1) so the shape check is uniform.
    if draft_token_ids.ndim == 1:
        draft_token_ids = jnp.expand_dims(draft_token_ids, axis=-1)

    # --- Assertions ---
    assert draft_token_ids.shape == (batch_size, num_speculative_tokens)

    # Step 0: base_token_ids [42, 60]
    # Step 1: [43, 61]
    # Step 2: [44, 62]
    # ...
    # Broadcasting a column of base ids against a row of step offsets gives
    # expected_tokens[i, j] == base_token_ids[i] + j.
    step_offsets = np.arange(num_speculative_tokens, dtype=np.int64)
    base_column = np.asarray(base_token_ids, dtype=np.int64)[:, np.newaxis]
    expected_tokens = base_column + step_offsets

    assert jnp.array_equal(draft_token_ids, jnp.array(expected_tokens))
tests/test_base.py CHANGED
@@ -1,3 +1,17 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  import logging
2
16
  import unittest
3
17
  import warnings
tests/test_envs.py CHANGED
@@ -60,6 +60,7 @@ def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
60
60
  monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0")
61
61
  monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "0")
62
62
  monkeypatch.setenv("NEW_MODEL_DESIGN", "0")
63
+ monkeypatch.setenv("ENABLE_QUANTIZED_MATMUL_KERNEL", "0")
63
64
  monkeypatch.setenv("USE_MOE_EP_KERNEL", "0")
64
65
 
65
66
  # Test SKIP_JAX_PRECOMPILE (default False)
@@ -86,6 +87,82 @@ def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
86
87
  monkeypatch.setenv("USE_MOE_EP_KERNEL", "1")
87
88
  assert envs.USE_MOE_EP_KERNEL is True
88
89
 
90
+ # Test ENABLE_QUANTIZED_MATMUL_KERNEL (default False)
91
+ assert envs.ENABLE_QUANTIZED_MATMUL_KERNEL is False
92
+ monkeypatch.setenv("ENABLE_QUANTIZED_MATMUL_KERNEL", "1")
93
+ assert envs.ENABLE_QUANTIZED_MATMUL_KERNEL is True
94
+
95
+
96
def test_boolean_env_vars_string_values(monkeypatch: pytest.MonkeyPatch):
    """Boolean env vars accept 'True'/'False' spellings in mixed case."""

    # (variable, raw value, expected parsed value), checked strictly in this
    # order so each lookup follows the same setenv sequence as before.
    cases = (
        ("NEW_MODEL_DESIGN", "True", True),
        ("NEW_MODEL_DESIGN", "true", True),
        ("NEW_MODEL_DESIGN", "False", False),
        ("NEW_MODEL_DESIGN", "false", False),
        ("SKIP_JAX_PRECOMPILE", "True", True),
        ("SKIP_JAX_PRECOMPILE", "false", False),
        ("VLLM_XLA_CHECK_RECOMPILATION", "TRUE", True),
        ("VLLM_XLA_CHECK_RECOMPILATION", "FALSE", False),
        ("USE_MOE_EP_KERNEL", "true", True),
        ("USE_MOE_EP_KERNEL", "False", False),
    )

    for var_name, raw_value, expected in cases:
        monkeypatch.setenv(var_name, raw_value)
        # Identity check: the value must be the bool singleton itself.
        assert getattr(envs, var_name) is expected
132
+
133
+
134
def test_boolean_env_vars_invalid_values(monkeypatch: pytest.MonkeyPatch):
    """Unrecognised boolean strings raise ``ValueError`` on attribute read."""

    # (variable, rejected raw value, pattern the error message must match)
    cases = (
        ("NEW_MODEL_DESIGN", "yes",
         "Invalid boolean value 'yes' for NEW_MODEL_DESIGN"),
        ("NEW_MODEL_DESIGN", "2",
         "Invalid boolean value '2' for NEW_MODEL_DESIGN"),
        ("SKIP_JAX_PRECOMPILE", "invalid",
         "Invalid boolean value 'invalid' for SKIP_JAX_PRECOMPILE"),
    )

    for var_name, raw_value, message in cases:
        monkeypatch.setenv(var_name, raw_value)
        # The error is expected from the attribute read, not from setenv.
        with pytest.raises(ValueError, match=message):
            getattr(envs, var_name)
155
+
156
+
157
def test_boolean_env_vars_empty_string(monkeypatch: pytest.MonkeyPatch):
    """An empty string makes a boolean env var read as its default."""

    # Both variables are expected to read as False (their default) here.
    for var_name in ("NEW_MODEL_DESIGN", "SKIP_JAX_PRECOMPILE"):
        monkeypatch.setenv(var_name, "")
        assert getattr(envs, var_name) is False  # Should return default
165
+
89
166
 
90
167
  def test_integer_env_vars(monkeypatch: pytest.MonkeyPatch):
91
168
  # Ensure clean environment for integer vars by setting to defaults
@@ -179,7 +256,7 @@ def test_disaggregated_serving_env_vars(monkeypatch: pytest.MonkeyPatch):
179
256
 
180
257
  def test_model_impl_type_default(monkeypatch: pytest.MonkeyPatch):
181
258
  monkeypatch.delenv("MODEL_IMPL_TYPE", raising=False)
182
- assert envs.MODEL_IMPL_TYPE == "flax_nnx"
259
+ assert envs.MODEL_IMPL_TYPE == "auto"
183
260
 
184
261
 
185
262
  def test_cache_preserves_values_across_env_changes(
tests/test_tpu_info.py CHANGED
@@ -1,3 +1,17 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  import os
2
16
  from unittest.mock import MagicMock, patch
3
17
 
tests/test_utils.py CHANGED
@@ -9,7 +9,7 @@ import pytest
9
9
  from tpu_inference.utils import (GBYTES, enable_megacore,
10
10
  get_jax_dtype_from_str_dtype, get_megacore,
11
11
  get_padded_head_dim, hbm_usage_bytes,
12
- hbm_usage_gb, quantize_kv)
12
+ hbm_usage_gb)
13
13
 
14
14
 
15
15
  def test_enable_and_get_megacore():
@@ -182,48 +182,6 @@ def test_get_padded_head_dim(head_dim, expected_padded_head_dim):
182
182
  assert get_padded_head_dim(head_dim) == expected_padded_head_dim
183
183
 
184
184
 
185
- def test_quantize_kv_float8_e4m3fn():
186
- """Tests the quantize_kv function with float8_e4m3fn dtype."""
187
- key = jnp.array([-1.0, 0.5, 1.0, 1.5])
188
- value = jnp.array([2.0, 0.0, -2.0, -3.0])
189
- kv_cache_quantized_dtype = jnp.float8_e4m3fn
190
- k_scale = 0.1
191
- v_scale = 0.2
192
-
193
- quantized_key, quantized_value = quantize_kv(key, value,
194
- kv_cache_quantized_dtype,
195
- k_scale, v_scale)
196
-
197
- # Expected key: key / k_scale -> clip -> astype
198
- # [-10., 5., 10., 15.] are within float8_e4m3fn range
199
- expected_key = jnp.array([-10.0, 5.0, 10.0, 15.0], dtype=jnp.float8_e4m3fn)
200
-
201
- # Expected value: value / v_scale -> clip -> astype
202
- # [10., 0., -10., -15.] are within float8_e4m3fn range
203
- expected_value = jnp.array([10.0, 0.0, -10.0, -15.0],
204
- dtype=jnp.float8_e4m3fn)
205
-
206
- assert jnp.array_equal(quantized_key, expected_key)
207
- assert jnp.array_equal(quantized_value, expected_value)
208
-
209
- # Test clipping
210
- dtype_info = jnp.finfo(kv_cache_quantized_dtype)
211
- minval, maxval = float(dtype_info.min), float(dtype_info.max)
212
-
213
- # Values that will be outside the range after scaling
214
- key_clip = jnp.array([minval * k_scale * 2, maxval * k_scale * 2])
215
- value_clip = jnp.array([maxval * v_scale * 2, minval * v_scale * 2])
216
- quantized_key_clip, quantized_value_clip = quantize_kv(
217
- key_clip, value_clip, kv_cache_quantized_dtype, k_scale, v_scale)
218
-
219
- # Values should be clipped to the min/max of the float8 dtype
220
- expected_key_clip = jnp.array([minval, maxval], dtype=jnp.float8_e4m3fn)
221
- expected_value_clip = jnp.array([maxval, minval], dtype=jnp.float8_e4m3fn)
222
-
223
- assert jnp.array_equal(quantized_key_clip, expected_key_clip)
224
- assert jnp.array_equal(quantized_value_clip, expected_value_clip)
225
-
226
-
227
185
  def test_get_jax_dtype_from_str_dtype():
228
186
  """
229
187
  Test the get_jax_dtype_from_str_dtype function
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.