tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
--- /dev/null
+++ b/tpu_inference/layers/jax/sample/sampling_metadata.py
@@ -0,0 +1,90 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+from dataclasses import dataclass
+from typing import Optional
+
+import jax
+import jax.numpy as jnp
+import torch
+from jax.sharding import Mesh
+
+from tpu_inference.runner.input_batch import InputBatch
+from tpu_inference.utils import device_array
+
+DEFAULT_SAMPLING_PARAMS = dict(
+    temperature=-1.0,
+    top_k=0,
+    top_p=1.0,
+)
+
+
+@functools.partial(
+    jax.tree_util.register_dataclass,
+    data_fields=[
+        "temperature",
+        "top_k",
+        "top_p",
+    ],
+    meta_fields=["do_sampling", "logprobs"],
+)
+@dataclass
+class TPUSupportedSamplingMetadata:
+    temperature: Optional[jnp.ndarray] = None
+    top_k: Optional[jnp.ndarray] = None
+    top_p: Optional[jnp.ndarray] = None
+    do_sampling: bool = False
+    logprobs: bool = False
+
+    @classmethod
+    def from_input_batch(
+        cls,
+        mesh: Mesh,
+        input_batch: InputBatch,
+        padded_num_reqs: int,
+        sharding: Optional[jax.sharding.Sharding] = None,
+    ) -> "TPUSupportedSamplingMetadata":
+        needs_logprobs = input_batch.max_num_logprobs > 0 if input_batch.max_num_logprobs else False
+        if input_batch.all_greedy:
+            return cls(do_sampling=False, logprobs=needs_logprobs)
+        num_reqs = input_batch.num_reqs
+
+        def fill_slice(cpu_torch_tensor: torch.Tensor,
+                       fill_val: float) -> torch.Tensor:
+            # Pad the entries beyond num_reqs with the default value.
+            cpu_torch_tensor[num_reqs:padded_num_reqs] = fill_val
+            return cpu_torch_tensor
+
+        temp_tensor = fill_slice(input_batch.temperature_cpu,
+                                 DEFAULT_SAMPLING_PARAMS["temperature"])
+        top_k_tensor = fill_slice(input_batch.top_k_cpu,
+                                  DEFAULT_SAMPLING_PARAMS["top_k"])
+        top_p_tensor = fill_slice(input_batch.top_p_cpu,
+                                  DEFAULT_SAMPLING_PARAMS["top_p"])
+
+        # Slice the persistent CPU tensors to the fixed, pre-compiled padded shape.
+        return cls(
+            temperature=device_array(mesh,
+                                     temp_tensor[:padded_num_reqs],
+                                     sharding=sharding),
+            top_p=device_array(mesh,
+                               top_p_tensor[:padded_num_reqs],
+                               sharding=sharding),
+            top_k=device_array(mesh,
+                               top_k_tensor[:padded_num_reqs],
+                               sharding=sharding),
+            do_sampling=not input_batch.all_greedy,
+            logprobs=needs_logprobs,
+        )
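
The hunk above registers the dataclass with jax.tree_util.register_dataclass, so the array fields (temperature, top_k, top_p) flow through jit/tree_flatten as pytree leaves while do_sampling and logprobs travel as static metadata. A minimal hedged sketch of that behaviour follows; the shapes and values are invented for illustration and are not taken from the package.

# Hypothetical sketch, not part of the wheel: shows that only the array fields
# of TPUSupportedSamplingMetadata are pytree leaves.
import jax
import jax.numpy as jnp

from tpu_inference.layers.jax.sample.sampling_metadata import (
    TPUSupportedSamplingMetadata)

meta = TPUSupportedSamplingMetadata(
    temperature=jnp.full((8, ), 0.7, dtype=jnp.float32),
    top_k=jnp.zeros((8, ), dtype=jnp.int32),
    top_p=jnp.ones((8, ), dtype=jnp.float32),
    do_sampling=True,
)
leaves, treedef = jax.tree_util.tree_flatten(meta)
assert len(leaves) == 3  # temperature, top_k, top_p; the two bools are metadata
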
--- /dev/null
+++ b/tpu_inference/layers/jax/transformer_block.py
@@ -0,0 +1,121 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Any, Optional, Tuple
+
+# Flax and JAX sharding imports
+import jax
+from flax import nnx
+
+from tpu_inference.layers.jax.attention.attention import (AttentionMetadata,
+                                                           KVCache)
+from tpu_inference.layers.jax.layers import DenseFFW
+from tpu_inference.layers.jax.moe.moe import MoE
+
+
+@dataclass(kw_only=True)
+class TransformerBlock(nnx.Module):
+    """
+    A heavyweight module that serves as a stateful live block during serving.
+
+    custom_module can be either a dense module (i.e., DenseFFW) or MoE.
+    """
+    pre_attention_norm: nnx.Module
+    pre_mlp_norm: nnx.Module
+    custom_module: Optional[nnx.Module] = None
+    attn: nnx.Module
+    use_attention_rope: bool = True
+    quant: Any | None = None
+
+    def __call__(
+        self, x_TD: jax.Array, is_prefill: bool, kv_cache: KVCache,
+        attention_metadata: AttentionMetadata
+    ) -> Tuple[KVCache, jax.Array]:
+        # Attn Block
+        attn_residual_TD = x_TD
+        x_TD = self.pre_attention_norm(x_TD)
+        new_cache, attn_output_TD = self.attn(x_TD, is_prefill, kv_cache,
+                                              attention_metadata,
+                                              self.use_attention_rope)
+        attn_output_TD += attn_residual_TD
+
+        # FFW Block
+        ffw_residual_TD = attn_output_TD
+        normed_ffw_input_TD = self.pre_mlp_norm(attn_output_TD)
+        logits_TD = self.custom_module(normed_ffw_input_TD)
+        logits_TD += ffw_residual_TD
+        return new_cache, logits_TD
+
+
+@dataclass(kw_only=True)
+class SharedExpertsTransformerBlock(TransformerBlock):
+    """A modified TransformerBlock that sums the MoE layer output with the shared expert output.
+
+    Users can provide the FFW layer in two ways:
+    1. Pass the module (either `MoE` or `DenseFFW`) to the `custom_module`
+       attribute.
+    2. Specify the `moe_ffw` or `dense_ffw` attributes
+       (e.g., for passing quantized modules).
+
+    Attributes:
+        moe_ffw: Optional MoE layer.
+        dense_ffw: Optional DenseFFW layer.
+        shared_experts: Optional shared experts module, used if MoE is enabled.
+
+    If an `MoE` layer is used (from either path), its output is summed
+    with the output of the `shared_experts` module.
+    """
+
+    moe_ffw: Optional[MoE] = None
+    dense_ffw: Optional[DenseFFW] = None
+    shared_experts: Optional[DenseFFW] = None
+
+    def __call__(self, x_TD, is_prefill, kv_cache, attention_metadata):
+        # Attn Block
+        attn_residual_TD = x_TD
+        x_TD = self.pre_attention_norm(x_TD)
+        new_cache, attn_output_TD = self.attn(x_TD, is_prefill, kv_cache,
+                                              attention_metadata,
+                                              self.use_attention_rope)
+        attn_output_TD += attn_residual_TD
+
+        # FFW Block
+        ffw_residual_TD = attn_output_TD
+        normed_ffw_input_TD = self.pre_mlp_norm(attn_output_TD)
+
+        if isinstance(self.custom_module, MoE):
+            moe_layer = self.custom_module
+        else:
+            moe_layer = self.moe_ffw
+
+        if isinstance(self.custom_module, DenseFFW):
+            dense_layer = self.custom_module
+        else:
+            dense_layer = self.dense_ffw
+
+        if moe_layer is not None:
+            logits_TD = moe_layer(normed_ffw_input_TD)
+            # Add the shared expert outputs to the MoE outputs.
+            shared_expert_output_TD = self.shared_experts(normed_ffw_input_TD)
+            logits_TD += shared_expert_output_TD
+        elif dense_layer is not None:
+            logits_TD = dense_layer(normed_ffw_input_TD)
+        else:
+            raise ValueError(
+                "Neither custom_module, moe_ffw, nor dense_ffw is set for this SharedExpertsTransformerBlock!"
+            )
+
+        logits_TD += ffw_residual_TD
+        return new_cache, logits_TD
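
The SharedExpertsTransformerBlock docstring above describes two ways to supply the FFW layer (via custom_module, or via the moe_ffw/dense_ffw attributes). The hedged sketch below exercises the dense_ffw path using throwaway stand-in modules; the names _Identity, _NoopAttention and _ToyFFW are invented here purely to satisfy the call signatures and do not exist in the package.

# Hypothetical sketch, not part of the wheel.
import jax.numpy as jnp
from flax import nnx

from tpu_inference.layers.jax.transformer_block import \
    SharedExpertsTransformerBlock


class _Identity(nnx.Module):  # stands in for a norm layer
    def __call__(self, x):
        return x


class _NoopAttention(nnx.Module):  # stands in for the attention module
    def __call__(self, x, is_prefill, kv_cache, attention_metadata, use_rope):
        return kv_cache, x


class _ToyFFW(nnx.Module):  # stands in for a DenseFFW layer
    def __call__(self, x):
        return x * 2.0


block = SharedExpertsTransformerBlock(
    pre_attention_norm=_Identity(),
    pre_mlp_norm=_Identity(),
    attn=_NoopAttention(),
    dense_ffw=_ToyFFW(),  # path 2 from the docstring: set dense_ffw directly
)
kv_cache, out = block(jnp.ones((4, 8)), is_prefill=False, kv_cache=None,
                      attention_metadata=None)
assert out.shape == (4, 8)
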
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
--- /dev/null
+++ b/tpu_inference/layers/vllm/attention.py
@@ -0,0 +1,221 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import functools
+from typing import Optional, Tuple
+
+import jax
+import jax.numpy as jnp
+import torch
+from jax.sharding import Mesh
+from torchax.interop import jax_view, torch_view
+from torchax.ops.mappings import t2j
+from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer, AttentionType)
+
+from tpu_inference import utils
+from tpu_inference.layers.common.attention_interface import attention
+from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.quantization import quantize_kv
+from tpu_inference.logger import init_logger
+from tpu_inference.models.vllm.vllm_model_wrapper_context import \
+    get_vllm_model_wrapper_context
+
+logger = init_logger(__name__)
+
+
+class PallasAttentionBackend(AttentionBackend):
+
+    @staticmethod
+    def get_name() -> str:
+        return "PALLAS"
+
+    @staticmethod
+    def get_impl_cls() -> type["PallasAttentionBackendImpl"]:
+        return PallasAttentionBackendImpl
+
+
+class PallasAttentionBackendImpl(AttentionImpl):
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: int,
+        alibi_slopes: list[float] | None,
+        sliding_window: int | None,
+        kv_cache_dtype: str,
+        logits_soft_cap: float | None = None,
+        attn_type: AttentionType = AttentionType.DECODER,
+        kv_sharing_target_layer_name: str | None = None,
+        sinks: torch.Tensor | None = None,
+    ) -> None:
+        self.num_heads = num_heads
+        self.head_size = head_size
+        self.scale = float(scale)
+        self.num_kv_heads = num_kv_heads
+        self.sliding_window = sliding_window
+        self.logits_soft_cap = logits_soft_cap
+        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
+
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+        if alibi_slopes is not None:
+            raise NotImplementedError("Alibi slopes are not supported.")
+        self.kv_cache_quantized_dtype = None
+        if kv_cache_dtype != "auto":
+            self.kv_cache_quantized_dtype = utils.get_jax_dtype_from_str_dtype(
+                kv_cache_dtype)
+
+        if attn_type != AttentionType.DECODER:
+            raise NotImplementedError("Encoder self-attention and "
+                                      "encoder/decoder cross-attention "
+                                      "are not implemented for "
+                                      "PallasAttentionBackendImpl")
+
+        self.sinks = sinks
+        if self.sinks is not None:
+            assert self.sinks.shape[0] == num_heads, (
+                "Sinks must have the same number of heads as the number of "
+                "heads in the layer")
+
+    def process_weights_after_loading(self, act_dtype: torch.dtype):
+        # TODO(kyuyeunk): Shard the sinks along the num_heads dim.
+        if self.sinks is not None:
+            sinks = t2j(self.sinks, use_dlpack=False)
+            sinks = torch_view(sinks.astype(jnp.float32))
+            self.sinks = torch.nn.Parameter(sinks, requires_grad=False)
+
+    def forward(
+        self,
+        layer: AttentionLayer,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if output_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported for "
+                "PallasAttentionBackendImpl")
+
+        if kv_cache.numel():
+            raise RuntimeError(
+                "KV cache from vLLM Attention layer should be empty but has "
+                f"the size of {kv_cache.numel()}.")
+
+        del kv_cache  # Use kv_cache from vllm wrapper context values instead.
+
+        vllm_model_wrapper_context = get_vllm_model_wrapper_context()
+        kv_cache_index = vllm_model_wrapper_context.layer_name_to_kvcache_index[
+            layer.layer_name]
+        kv_cache = vllm_model_wrapper_context.kv_caches[kv_cache_index]
+
+        mesh = vllm_model_wrapper_context.mesh
+
+        query, key, value = jax_view(query), jax_view(key), jax_view(value)
+        q_scale = k_scale = v_scale = None
+        if self.kv_cache_quantized_dtype:
+            key, value = quantize_kv(self.kv_cache_quantized_dtype, key, value,
+                                     layer._k_scale_float,
+                                     layer._v_scale_float)
+            # TODO(kyuyeunk): Enable w8a8 when VREG spill issue is resolved.
+            # q_scale = layer._q_scale_float
+            k_scale = layer._k_scale_float
+            v_scale = layer._v_scale_float
+
+        sinks = jax_view(self.sinks)
+
+        new_kv_cache, outputs = _jax_attn_func(
+            kv_cache,
+            query,
+            key,
+            value,
+            sinks,
+            attn_metadata,
+            mesh,
+            self.scale,
+            self.head_size,
+            self.num_heads,
+            self.num_kv_heads,
+            q_scale,
+            k_scale,
+            v_scale,
+            self.sliding_window,
+        )
+        vllm_model_wrapper_context.kv_caches[kv_cache_index] = new_kv_cache
+
+        return torch_view(outputs)
+
+
+@functools.partial(
+    jax.jit,
+    static_argnames=(
+        "mesh",
+        "scale",
+        "head_size",
+        "num_heads",
+        "num_kv_heads",
+        "q_scale",
+        "k_scale",
+        "v_scale",
+        "sliding_window",
+    ),
+    donate_argnames=("kv_cache", ),
+)
+def _jax_attn_func(
+    kv_cache: jax.Array,
+    q: jax.Array,
+    k: jax.Array,
+    v: jax.Array,
+    sinks: jax.Array | None,
+    attention_metadata: AttentionMetadata,
+    mesh: Mesh,
+    scale: float,
+    head_size: int,
+    num_heads: int,
+    num_kv_heads: int,
+    q_scale: float | None = None,
+    k_scale: float | None = None,
+    v_scale: float | None = None,
+    sliding_window: int | None = None,
+) -> Tuple[jax.Array, jax.Array]:
+    del scale  # Unused for now, as the attention function applies a default scale.
+
+    # Get shapes from vLLM.
+    q_len, q_compute_dim = q.shape
+    k_len, k_compute_dim = k.shape
+    assert k.shape == v.shape
+    assert q_compute_dim == head_size * num_heads
+    assert k_compute_dim == head_size * num_kv_heads
+
+    # Convert the shapes from vLLM's convention to what the attention function expects.
+    # (q_len, num_heads, head_size)
+    q = q.reshape(q_len, num_heads, head_size)
+    # (k_len, num_kv_heads, head_size)
+    k = k.reshape(k_len, num_kv_heads, head_size)
+    v = v.reshape(k_len, num_kv_heads, head_size)
+
+    new_kv_cache, outputs = attention(
+        kv_cache,
+        q,
+        k,
+        v,
+        attention_metadata,
+        mesh,
+        q_scale=q_scale,
+        k_scale=k_scale,
+        v_scale=v_scale,
+        sinks=sinks,
+        attention_chunk_size=sliding_window,
+    )
+
+    # Convert the shape back to vLLM's convention.
+    assert outputs.shape[0] == q_len
+    assert outputs.shape[1] == num_heads
+    assert outputs.shape[2] == head_size
+    outputs = outputs.reshape(q_len, q_compute_dim)
+
+    return new_kv_cache, outputs
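
_jax_attn_func above bridges two layouts: vLLM hands the backend flattened [num_tokens, num_heads * head_size] activations, while the attention helper works on [num_tokens, num_heads, head_size] arrays, so the function reshapes on the way in and back out. A small hedged illustration of that round trip follows; the dimensions are made up and nothing here calls the actual kernel.

# Hypothetical sketch, not part of the wheel: the reshape round trip only.
import jax.numpy as jnp

num_tokens, num_heads, head_size = 16, 8, 128
q_vllm = jnp.zeros((num_tokens, num_heads * head_size))      # vLLM convention
q_kernel = q_vllm.reshape(num_tokens, num_heads, head_size)  # kernel convention
out_vllm = q_kernel.reshape(num_tokens, num_heads * head_size)
assert out_vllm.shape == q_vllm.shape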