tpu-inference 0.11.1.dev202511150811__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (179)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +105 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +654 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +96 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +182 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +236 -0
  27. tpu_inference/__init__.py +34 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +51 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +728 -0
  37. tpu_inference/distributed/utils.py +59 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +107 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +362 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +390 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +507 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +105 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +311 -0
  120. tpu_inference/mock/__init__.py +0 -0
  121. tpu_inference/mock/vllm_config_utils.py +28 -0
  122. tpu_inference/mock/vllm_envs.py +1219 -0
  123. tpu_inference/mock/vllm_logger.py +212 -0
  124. tpu_inference/mock/vllm_logging_utils.py +15 -0
  125. tpu_inference/models/__init__.py +0 -0
  126. tpu_inference/models/common/__init__.py +0 -0
  127. tpu_inference/models/common/model_loader.py +444 -0
  128. tpu_inference/models/jax/__init__.py +0 -0
  129. tpu_inference/models/jax/deepseek_v3.py +868 -0
  130. tpu_inference/models/jax/gpt_oss.py +492 -0
  131. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  132. tpu_inference/models/jax/llama3.py +375 -0
  133. tpu_inference/models/jax/llama4.py +629 -0
  134. tpu_inference/models/jax/llama_eagle3.py +333 -0
  135. tpu_inference/models/jax/phi3.py +376 -0
  136. tpu_inference/models/jax/qwen2.py +375 -0
  137. tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
  138. tpu_inference/models/jax/qwen3.py +302 -0
  139. tpu_inference/models/jax/utils/__init__.py +0 -0
  140. tpu_inference/models/jax/utils/file_utils.py +96 -0
  141. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  142. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  143. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  144. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  145. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  146. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  147. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  148. tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
  149. tpu_inference/models/jax/utils/weight_utils.py +529 -0
  150. tpu_inference/models/vllm/__init__.py +0 -0
  151. tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
  152. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  153. tpu_inference/platforms/__init__.py +2 -0
  154. tpu_inference/platforms/tpu_platform.py +269 -0
  155. tpu_inference/runner/__init__.py +0 -0
  156. tpu_inference/runner/block_table.py +122 -0
  157. tpu_inference/runner/compilation_manager.py +780 -0
  158. tpu_inference/runner/input_batch.py +435 -0
  159. tpu_inference/runner/kv_cache.py +132 -0
  160. tpu_inference/runner/kv_cache_manager.py +479 -0
  161. tpu_inference/runner/lora_utils.py +92 -0
  162. tpu_inference/runner/multimodal_manager.py +217 -0
  163. tpu_inference/runner/persistent_batch_manager.py +244 -0
  164. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  165. tpu_inference/runner/structured_decoding_manager.py +88 -0
  166. tpu_inference/runner/tpu_runner.py +1620 -0
  167. tpu_inference/runner/utils.py +426 -0
  168. tpu_inference/spec_decode/__init__.py +0 -0
  169. tpu_inference/spec_decode/jax/__init__.py +0 -0
  170. tpu_inference/spec_decode/jax/eagle3.py +367 -0
  171. tpu_inference/tpu_info.py +77 -0
  172. tpu_inference/utils.py +317 -0
  173. tpu_inference/worker/__init__.py +0 -0
  174. tpu_inference/worker/tpu_worker.py +321 -0
  175. tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
  176. tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
  177. tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
  178. tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
  179. tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tpu_inference/layers/jax/transformer_block.py
@@ -0,0 +1,107 @@
+ from dataclasses import dataclass
+ from typing import Any, Optional, Tuple
+
+ # Flax and JAX sharding imports
+ import jax
+ from flax import nnx
+
+ from tpu_inference.layers.jax.attention.attention import (AttentionMetadata,
+                                                           KVCache)
+ from tpu_inference.layers.jax.layers import DenseFFW
+ from tpu_inference.layers.jax.moe.moe import MoE
+
+
+ @dataclass(kw_only=True)
+ class TransformerBlock(nnx.Module):
+     """
+     A heavy weight module which serves as the stateful live blocks in serving
+
+     custom_module can be either a dense module (i.e., DenseFFW) or MoE.
+     """
+     pre_attention_norm: nnx.Module
+     pre_mlp_norm: nnx.Module
+     custom_module: Optional[nnx.Module] = None
+     attn: nnx.Module
+     use_attention_rope: bool = True
+     quant: Any | None = None
+
+     def __call__(
+         self, x_TD: jax.Array, is_prefill: bool, kv_cache: KVCache,
+         attention_metadata: AttentionMetadata
+     ) -> Tuple[KVCache, jax.Array]:
+         # Attn Block
+         attn_residual_TD = x_TD
+         x_TD = self.pre_attention_norm(x_TD)
+         new_cache, attn_output_TD = self.attn(x_TD, is_prefill, kv_cache,
+                                               attention_metadata,
+                                               self.use_attention_rope)
+         attn_output_TD += attn_residual_TD
+
+         # FFW Block
+         ffw_residual_TD = attn_output_TD
+         normed_ffw_input_TD = self.pre_mlp_norm(attn_output_TD)
+         logits_TD = self.custom_module(normed_ffw_input_TD)
+         logits_TD += ffw_residual_TD
+         return new_cache, logits_TD
+
+
+ @dataclass(kw_only=True)
+ class SharedExpertsTransformerBlock(TransformerBlock):
+     """Create a modified TransformerBlock that sums MoE layer output with shared expert output.
+
+     Users can provide the FFW layer in two ways:
+     1. Pass the module (either `MoE` or `DenseFFW`) to the `custom_module`
+        attribute.
+     2. Specify the `moe_ffw` or `dense_ffw` attributes
+        (e.g., for passing quantized modules).
+
+     Attributes:
+         moe_ffw: Optional MoE layer.
+         dense_ffw: Optional DFF layer.
+         shared_experts: Optional shared experts module, used if MoE is enabled.
+
+     If an `MoE` layer is used (from either path), its output is summed
+     with the `shared_experts` module.
+     """
+
+     moe_ffw: Optional[MoE] = None
+     dense_ffw: Optional[DenseFFW] = None
+     shared_experts: Optional[DenseFFW] = None
+
+     def __call__(self, x_TD, is_prefill, kv_cache, attention_metadata):
+         # Attn Block
+         attn_residual_TD = x_TD
+         x_TD = self.pre_attention_norm(x_TD)
+         new_cache, attn_output_TD = self.attn(x_TD, is_prefill, kv_cache,
+                                               attention_metadata,
+                                               self.use_attention_rope)
+         attn_output_TD += attn_residual_TD
+
+         # FFW Block
+         ffw_residual_TD = attn_output_TD
+         normed_ffw_input_TD = self.pre_mlp_norm(attn_output_TD)
+
+         if isinstance(self.custom_module, MoE):
+             moe_layer = self.custom_module
+         else:
+             moe_layer = self.moe_ffw
+
+         if isinstance(self.custom_module, DenseFFW):
+             dense_layer = self.custom_module
+         else:
+             dense_layer = self.dense_ffw
+
+         if moe_layer is not None:
+             logits_TD = moe_layer(normed_ffw_input_TD)
+             # Add the shared expert outputs to the MoE outputs.
+             shared_expert_output_TD = self.shared_experts(normed_ffw_input_TD)
+             logits_TD += shared_expert_output_TD
+         elif dense_layer is not None:
+             logits_TD = dense_layer(normed_ffw_input_TD)
+         else:
+             raise ValueError(
+                 "Neither custom_module, moe_ffw nor dense_ffw attribute is set for this SharedExpertsTransformerBlock!"
+             )
+
+         logits_TD += ffw_residual_TD
+         return new_cache, logits_TD
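
To make the FFW dispatch in SharedExpertsTransformerBlock easier to follow when skimming the diff, here is a minimal standalone sketch of the same selection logic. It does not import tpu_inference; StubMoE and StubDense are hypothetical stand-ins for MoE and DenseFFW, and the lambdas stand in for the shared-experts module. This is illustrative only, not package code.

# Minimal standalone sketch of SharedExpertsTransformerBlock's FFW dispatch.
# StubMoE/StubDense are hypothetical stand-ins for MoE/DenseFFW.
from dataclasses import dataclass
from typing import Callable, Optional


class StubMoE:
    def __call__(self, x):
        return 2.0 * x


class StubDense:
    def __call__(self, x):
        return 3.0 * x


@dataclass
class FFWDispatch:
    custom_module: Optional[object] = None
    moe_ffw: Optional[StubMoE] = None
    dense_ffw: Optional[StubDense] = None
    shared_experts: Optional[Callable] = None

    def __call__(self, x):
        # custom_module takes precedence; otherwise fall back to moe_ffw/dense_ffw.
        moe_layer = self.custom_module if isinstance(self.custom_module, StubMoE) else self.moe_ffw
        dense_layer = self.custom_module if isinstance(self.custom_module, StubDense) else self.dense_ffw
        if moe_layer is not None:
            out = moe_layer(x)
            out += self.shared_experts(x)  # MoE output is summed with shared experts.
        elif dense_layer is not None:
            out = dense_layer(x)
        else:
            raise ValueError("No FFW module configured.")
        return out


# MoE path: 2.0 * x from the MoE stub plus 0.5 * x from the shared experts.
block = FFWDispatch(moe_ffw=StubMoE(), shared_experts=lambda x: 0.5 * x)
print(block(1.0))  # 2.5
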
tpu_inference/layers/vllm/__init__.py (file without changes)
tpu_inference/layers/vllm/attention.py
@@ -0,0 +1,221 @@
+ # SPDX-License-Identifier: Apache-2.0
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+ import functools
+ from typing import Optional, Tuple
+
+ import jax
+ import jax.numpy as jnp
+ import torch
+ from jax.sharding import Mesh
+ from torchax.interop import jax_view, torch_view
+ from torchax.ops.mappings import t2j
+ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                               AttentionLayer, AttentionType)
+
+ from tpu_inference import utils
+ from tpu_inference.layers.common.attention_interface import attention
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.logger import init_logger
+ from tpu_inference.models.vllm.vllm_model_wrapper_context import \
+     get_vllm_model_wrapper_context
+
+ logger = init_logger(__name__)
+
+
+ class PallasAttentionBackend(AttentionBackend):
+
+     @staticmethod
+     def get_name() -> str:
+         return "PALLAS"
+
+     @staticmethod
+     def get_impl_cls() -> type["PallasAttentionBackendImpl"]:
+         return PallasAttentionBackendImpl
+
+
+ class PallasAttentionBackendImpl(AttentionImpl):
+
+     def __init__(
+         self,
+         num_heads: int,
+         head_size: int,
+         scale: float,
+         num_kv_heads: int,
+         alibi_slopes: list[float] | None,
+         sliding_window: int | None,
+         kv_cache_dtype: str,
+         logits_soft_cap: float | None = None,
+         attn_type: AttentionType = AttentionType.DECODER,
+         kv_sharing_target_layer_name: str | None = None,
+         sinks: torch.Tensor | None = None,
+     ) -> None:
+         self.num_heads = num_heads
+         self.head_size = head_size
+         self.scale = float(scale)
+         self.num_kv_heads = num_kv_heads
+         self.sliding_window = sliding_window
+         self.logits_soft_cap = logits_soft_cap
+         self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
+
+         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+         if alibi_slopes is not None:
+             raise NotImplementedError("Alibi slopes is not supported.")
+         self.kv_cache_quantized_dtype = None
+         if kv_cache_dtype != "auto":
+             self.kv_cache_quantized_dtype = utils.get_jax_dtype_from_str_dtype(
+                 kv_cache_dtype)
+
+         if attn_type != AttentionType.DECODER:
+             raise NotImplementedError("Encoder self-attention and "
+                                       "encoder/decoder cross-attention "
+                                       "are not implemented for "
+                                       "PallasAttentionBackendImpl")
+
+         self.sinks = sinks
+         if self.sinks is not None:
+             assert self.sinks.shape[0] == num_heads, (
+                 "Sinks must have the same number of heads as the number of "
+                 "heads in the layer")
+
+     def process_weights_after_loading(self, act_dtype: torch.dtype):
+         #TODO (kyuyeunk): Shard the sinks along num_heads dim
+         if self.sinks is not None:
+             sinks = t2j(self.sinks, use_dlpack=False)
+             sinks = torch_view(sinks.astype(jnp.float32))
+             self.sinks = torch.nn.Parameter(sinks, requires_grad=False)
+
+     def forward(
+         self,
+         layer: AttentionLayer,
+         query: torch.Tensor,
+         key: torch.Tensor,
+         value: torch.Tensor,
+         kv_cache: torch.Tensor,
+         attn_metadata: AttentionMetadata,
+         output: Optional[torch.Tensor] = None,
+         output_scale: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         if output_scale is not None:
+             raise NotImplementedError(
+                 "fused output quantization is not yet supported for "
+                 "PallasAttentionBackendImpl")
+
+         if kv_cache.numel():
+             raise RuntimeError(
+                 "KV cache from vLLM Attention layer should be empty but has "
+                 "the size of %s.", kv_cache.numel())
+
+         del kv_cache  # Use kv_cache from vllm wrapper context values instead.
+
+         vllm_model_wrapper_context = get_vllm_model_wrapper_context()
+         kv_cache_index = vllm_model_wrapper_context.layer_name_to_kvcache_index[
+             layer.layer_name]
+         kv_cache = vllm_model_wrapper_context.kv_caches[kv_cache_index]
+
+         mesh = vllm_model_wrapper_context.mesh
+
+         query, key, value = jax_view(query), jax_view(key), jax_view(value)
+         q_scale = k_scale = v_scale = None
+         if self.kv_cache_quantized_dtype:
+             key, value = utils.quantize_kv(key, value,
+                                            self.kv_cache_quantized_dtype,
+                                            layer._k_scale_float,
+                                            layer._v_scale_float)
+             # TODO(kyuyeunk): Enable w8a8 when VREG spill issue is resolved.
+             # q_scale = layer._q_scale_float
+             k_scale = layer._k_scale_float
+             v_scale = layer._v_scale_float
+
+         sinks = jax_view(self.sinks)
+
+         new_kv_cache, outputs = _jax_attn_func(
+             kv_cache,
+             query,
+             key,
+             value,
+             sinks,
+             attn_metadata,
+             mesh,
+             self.scale,
+             self.head_size,
+             self.num_heads,
+             self.num_kv_heads,
+             q_scale,
+             k_scale,
+             v_scale,
+             self.sliding_window,
+         )
+         vllm_model_wrapper_context.kv_caches[kv_cache_index] = new_kv_cache
+
+         return torch_view(outputs)
+
+
+ @functools.partial(
+     jax.jit,
+     static_argnames=(
+         "mesh",
+         "scale",
+         "head_size",
+         "num_heads",
+         "num_kv_heads",
+         "q_scale",
+         "k_scale",
+         "v_scale",
+         "sliding_window",
+     ),
+     donate_argnames=("kv_cache"),
+ )
+ def _jax_attn_func(
+     kv_cache: jax.Array,
+     q: jax.Array,
+     k: jax.Array,
+     v: jax.Array,
+     sinks: jax.Array | None,
+     attention_metadata: AttentionMetadata,
+     mesh: Mesh,
+     scale: float,
+     head_size: int,
+     num_heads: int,
+     num_kv_heads: int,
+     q_scale: float | None = None,
+     k_scale: float | None = None,
+     v_scale: float | None = None,
+     sliding_window: int | None = None,
+ ) -> Tuple[jax.Array, jax.Array]:
+     del scale  # Unused for now, as the attention function applies a default scale.
+
+     # Get shapes from vllm
+     q_len, q_compute_dim = q.shape
+     k_len, k_compute_dim = k.shape
+     assert k.shape == v.shape
+     assert q_compute_dim == head_size * num_heads
+     assert k_compute_dim == head_size * num_kv_heads
+
+     # Convert the shapes from vLLM's convetion to what the attention function expects
+     # bs, num_heads, q_len, head_size
+     q = q.reshape(q_len, num_heads, head_size)
+     # bs, num_kv_heads, k_len, head_size
+     k = k.reshape(k_len, num_kv_heads, head_size)
+     v = v.reshape(k_len, num_kv_heads, head_size)
+
+     new_kv_cache, outputs = attention(
+         kv_cache,
+         q,
+         k,
+         v,
+         attention_metadata,
+         mesh,
+         q_scale=q_scale,
+         k_scale=k_scale,
+         v_scale=v_scale,
+         sinks=sinks,
+         attention_chunk_size=sliding_window,
+     )
+
+     # Convert the shape back to vLLM's convention
+     assert outputs.shape[0] == q_len
+     assert outputs.shape[1] == num_heads
+     assert outputs.shape[2] == head_size
+     outputs = outputs.reshape(q_len, q_compute_dim)
+
+     return new_kv_cache, outputs
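
One detail worth highlighting in _jax_attn_func above is the layout conversion: vLLM hands the backend q/k/v flattened as (num_tokens, num_heads * head_size), the attention() call is given (num_tokens, num_heads, head_size), and the output is flattened back before returning. The standalone JAX sketch below only illustrates that round trip; the shape values are made up for the example and are not taken from the package.

import jax.numpy as jnp

num_tokens, num_heads, head_size = 8, 4, 64

# vLLM-side layout: one flattened compute dimension per token.
q_vllm = jnp.zeros((num_tokens, num_heads * head_size), dtype=jnp.bfloat16)

# Layout handed to the attention function.
q_kernel = q_vllm.reshape(num_tokens, num_heads, head_size)

# Flatten back to the vLLM convention, as done for `outputs`.
q_back = q_kernel.reshape(num_tokens, num_heads * head_size)

assert q_back.shape == q_vllm.shape == (8, 256)
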