tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +78 -1
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +38 -7
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +17 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +28 -5
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +74 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +88 -25
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -64
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +72 -37
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +45 -15
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +41 -16
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +42 -36
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +63 -50
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
- tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
tpu_inference/layers/jax/attention/deepseek_v3_attention.py
CHANGED

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import math
 from dataclasses import InitVar, dataclass
 from typing import Any, Tuple
@@ -6,14 +20,18 @@ import jax
 import jax.numpy as jnp
 from flax import nnx
 from flax.typing import Sharding
-from jax.experimental import shard_map
 from jax.sharding import Mesh
 from jax.sharding import PartitionSpec as P
 
 from tpu_inference import utils
+from tpu_inference.kernels.mla.v1.kernel import mla_ragged_paged_attention
 from tpu_inference.kernels.ragged_paged_attention.v3.kernel import \
     ragged_paged_attention
+from tpu_inference.kernels.ragged_paged_attention.v3.tuned_block_sizes import \
+    get_tuned_block_sizes
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.quantization import quantize_kv
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.jax.base import create_param
 from tpu_inference.layers.jax.layers import RMSNorm
 from tpu_inference.layers.jax.rope import DeepseekScalingRotaryEmbedding
@@ -48,8 +66,9 @@ class MLA(nnx.Module):
     rms_norm_eps: float
 
     # Sharding attributes
-
+    rd_sharding: Sharding = ()
     q_da_sharding: Sharding = ()
+    ap_sharding: Sharding = ()
     anh_sharding: Sharding = ()
     kv_da_sharding: Sharding = ()
 
@@ -66,6 +85,7 @@ class MLA(nnx.Module):
     rope_input_ordering: str = "split"
     quant: Any | None = None
     rope_mscale_all_dim: float = 1.0
+    use_mla_kernel: bool = False
 
     rngs: InitVar[nnx.Rngs]
 
@@ -77,10 +97,10 @@ class MLA(nnx.Module):
         self.N = self.num_attention_heads
         self.K = self.num_key_value_heads
         self.D = self.hidden_size
-
         self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
-
 
+        if not self.use_mla_kernel:
+            assert self.N == self.K, "N and K must be equal for MLA"
 
         if self.rope_scaling["factor"] <= 1.0:
             yarn_mscale = 1.0
@@ -108,10 +128,10 @@ class MLA(nnx.Module):
            self.q_da_sharding,
            self.dtype,
            random_init=self.random_init)
-        self.
+        self.kernel_q_up_proj_AP = create_param(
            rngs,
-           (self.q_lora_rank, self.N
-           self.
+           (self.q_lora_rank, self.N * self.qk_head_dim),
+           self.ap_sharding,
            self.dtype,
            random_init=self.random_init,
        )
@@ -122,17 +142,38 @@ class MLA(nnx.Module):
            self.dtype,
            random_init=self.random_init,
        )
-
-
-
-
-
-        self.
-
-
-
-
-
+        # NOTE (jacobplatin): we are keeping these variables as 3D because
+        # we would need to reshape them before the below projection,
+        # which caused issues as Qwix wasn't quantizing it correctly
+        # on the abstract pass
+        if self.use_mla_kernel:
+            self.kernel_k_up_proj_ANH = create_param(
+                rngs,
+                (self.kv_lora_rank, self.N, self.qk_nope_head_dim),
+                self.anh_sharding,
+                self.dtype,
+                random_init=self.random_init,
+            )
+            self.kernel_v_up_proj_ANH = create_param(
+                rngs,
+                (self.kv_lora_rank, self.N, self.v_head_dim),
+                self.anh_sharding,
+                self.dtype,
+                random_init=self.random_init,
+            )
+        else:
+            self.kernel_kv_up_proj_AL = create_param(
+                rngs,
+                (self.kv_lora_rank, self.N *
+                 (self.qk_nope_head_dim + self.v_head_dim)),
+                self.
+                ap_sharding,  # NOTE: we use the same sharding for kv_up_proj_AL and kernel_q_up_proj_AP
+                self.dtype,
+                random_init=self.random_init,
+            )
+        self.kernel_o_proj_RD = create_param(
+            rngs, (self.N * self.v_head_dim, self.D),
+            self.rd_sharding,
            self.dtype,
            random_init=self.random_init)
        self.q_rms_norm = RMSNorm(
@@ -188,17 +229,24 @@ class MLA(nnx.Module):
             q_TA = jnp.einsum("TD,DA -> TA", x_q_TD,
                               self.kernel_q_down_proj_DA.value)
             q_TA = self.q_rms_norm(q_TA)
-            # Query up projection.
-
-
+            # Query up projection, then reshape to TNH.
+            q_TP = jnp.einsum("TA,AP -> TP", q_TA,
+                              self.kernel_q_up_proj_AP.value)
+            q_TNH = q_TP.reshape(q_TA.shape[0], self.N, self.qk_head_dim)
             # Split the query into nope and rope.
             q_nope_TNH = q_TNH[..., :self.qk_nope_head_dim]
             q_rope_TNH = q_TNH[..., self.qk_nope_head_dim:]
             q_rope_TNH = self.rope.apply_rope(md.input_positions, q_rope_TNH)
-
-
-
-
+            if self.use_mla_kernel:
+                # Absorb the k up-projection matrix into q
+                q_TNA = jnp.einsum("TNH,ANH -> TNA", q_nope_TNH,
+                                   self.kernel_k_up_proj_ANH.value)
+                q_TNA = nnx.with_sharding_constraint(q_TNA, self.query_tnh)
+            else:
+                # Concatenate the nope and rope queries.
+                q_TNH = jnp.concatenate([q_nope_TNH, q_rope_TNH], axis=-1)
+                # Multiply the query by scaling factor
+                q_TNH = nnx.with_sharding_constraint(q_TNH, self.query_tnh)
 
         with jax.named_scope("kv_proj"):
             # KV down projection.
@@ -209,21 +257,30 @@ class MLA(nnx.Module):
             # Reshape k_rope_BSH to include head dimension for RoPE application
             k_rope_SNH = k_rope_SH[..., None, :]
             k_rope_SNH = self.rope.apply_rope(md.input_positions, k_rope_SNH)
-            k_rope_SNH
-
-
+            assert k_rope_SNH.shape[1] == 1
+            k_rope_SH = k_rope_SNH[:, 0, :]
+
             kv_SA = kv_SA[..., :self.kv_lora_rank]
             kv_SA = self.kv_rms_norm(kv_SA)
-
-
-
-
-
-
-
-
-
-
+            kv_SA = nnx.with_sharding_constraint(kv_SA, self.keyvalue_skh)
+
+            if not self.use_mla_kernel:
+                k_rope_SNH = jnp.broadcast_to(
+                    k_rope_SNH,
+                    (k_rope_SNH.shape[0], self.N, self.qk_rope_head_dim))
+                # KV up projection, then reshape to SN(Hk+Hv).
+                kv_SL = jnp.einsum("SA,AL -> SL", kv_SA,
+                                   self.kernel_kv_up_proj_AL.value)
+                kv_nope_SNH = kv_SL.reshape(
+                    kv_SA.shape[0], self.N,
+                    self.qk_nope_head_dim + self.v_head_dim)
+                # Split the latent kv vector into k nope vector and v vector.
+                k_nope_SNH = kv_nope_SNH[..., :self.qk_nope_head_dim]
+                v_SNH = kv_nope_SNH[..., self.qk_nope_head_dim:]
+                # Concatenate the key vector.
+                k_SNH = jnp.concatenate([k_nope_SNH, k_rope_SNH], axis=-1)
+                k_SNH = nnx.with_sharding_constraint(k_SNH, self.keyvalue_skh)
+                v_SNH = nnx.with_sharding_constraint(v_SNH, self.keyvalue_skh)
 
         with jax.named_scope("attn_op"):
             # TODO(wenxindongwork): K and V have different head dimension,
@@ -234,44 +291,67 @@ class MLA(nnx.Module):
             # q, k, v head dimension to be multiple of 128. For now, we will
             # pad the q, k, v dimension to multiple of 128.
             # We should update the MLA kv cache implementation in the future.
-
-
-
-
-
-
-
+            if not self.use_mla_kernel:  # MLA kernel handles padding
+                multiple_of_128 = ((self.qk_head_dim - 1) // 128 + 1) * 128
+                q_TNH = jnp.pad(q_TNH,
+                                ((0, 0), (0, 0),
+                                 (0, multiple_of_128 - self.qk_head_dim)))
+                k_SNH = jnp.pad(k_SNH,
+                                ((0, 0), (0, 0),
+                                 (0, multiple_of_128 - self.qk_head_dim)))
+                v_SNH = jnp.pad(v_SNH,
+                                ((0, 0), (0, 0),
+                                 (0, multiple_of_128 - self.v_head_dim)))
+
             q_scale = k_scale = v_scale = None
-            if self.kv_cache_quantized_dtype:
-                # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
-                # q_scale = self._q_scale
-                k_scale = self._k_scale
-                v_scale = self._v_scale
-                k_SNH, v_SNH = utils.quantize_kv(k_SNH, v_SNH,
-                                                 self.kv_cache_quantized_dtype,
-                                                 k_scale, v_scale)
-            new_kv_cache, outputs_TNH = self.attention(
-                is_prefill,
-                kv_cache,
-                q_TNH,
-                k_SNH,
-                v_SNH,
-                attention_metadata,
-                self.mesh,
-                q_scale,
-                k_scale,
-                v_scale,
-            )
-            # TODO(wenxindongwork): For now, unpad the outputs_TNH to match the v_head_dim.
-            # We shall add the MLA kv cache implementation in the future.
-            outputs_TNH = outputs_TNH[..., :self.v_head_dim]
 
-
-
-
-
-
-
+            # TODO(gpolovets): MLA does not currently support quantized KV!
+            if not self.use_mla_kernel:
+                if self.kv_cache_quantized_dtype:
+                    # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
+                    k_scale = self._k_scale
+                    v_scale = self._v_scale
+                    k_SNH, v_SNH = quantize_kv(self.kv_cache_quantized_dtype,
+                                               k_SNH, v_SNH, k_scale, v_scale)
+
+                new_kv_cache, outputs_TNH = self.attention(
+                    is_prefill,
+                    kv_cache,
+                    q_TNH,
+                    k_SNH,
+                    v_SNH,
+                    attention_metadata,
+                    self.mesh,
+                    q_scale,
+                    k_scale,
+                    v_scale,
+                )
+                # TODO(wenxindongwork): For now, unpad the outputs_TNH to match the v_head_dim.
+                # We shall add the MLA kv cache implementation in the future.
+                outputs_TNH = outputs_TNH[..., :self.v_head_dim]
+
+            else:
+                new_kv_cache, outputs_TNA = self.mla_attention(
+                    kv_cache,
+                    q_TNA,
+                    q_rope_TNH,
+                    kv_SA,
+                    k_rope_SH,
+                    attention_metadata,
+                    self.mesh,
+                )
+                outputs_TNH = jnp.einsum("TNA,ANH -> TNH", outputs_TNA,
+                                         self.kernel_v_up_proj_ANH.value)
+
+        with jax.named_scope("o_proj"):
+            outputs_TNH = nnx.with_sharding_constraint(
+                outputs_TNH, self.activation_attention_out_td)
+            outputs_TR = outputs_TNH.reshape(outputs_TNH.shape[0],
+                                             self.N * self.v_head_dim)
+            o_TD = jnp.einsum("TR,RD -> TD", outputs_TR,
+                              self.kernel_o_proj_RD.value)
+
+            return new_kv_cache, o_TD
 
     def attention(
         self,
@@ -326,21 +406,22 @@ class MLA(nnx.Module):
         out_specs = (self.attn_o_tnh, P(None, None, "model"))
 
         def _ragged_paged_attention(*args):
-
+            outputs = ragged_paged_attention(
                *args,
                sm_scale=self.scale,
                q_scale=q_scale,
                k_scale=k_scale,
                v_scale=v_scale,
            )
+            return outputs
 
        output_TNH, kv_cache = jax.jit(
-
+            jax.shard_map(
                _ragged_paged_attention,
                mesh=mesh,
                in_specs=in_specs,
                out_specs=out_specs,
-
+                check_vma=False,
            ))(
                q_TNH,
                k_SKH,
@@ -352,3 +433,115 @@ class MLA(nnx.Module):
            md.request_distribution,
        )
        return kv_cache, output_TNH
+
+    def mla_attention(
+        self,
+        kv_cache: KVCache,
+        q_TNA: jax.Array,
+        q_rope_TNH: jax.Array,
+        k_SA: jax.Array,
+        k_rope_SH: jax.Array,
+        attention_metadata: AttentionMetadata,
+        mesh: Mesh,
+    ) -> Tuple[KVCache, jax.Array]:
+        """Performs scaled dot-product attention and updates the KV cache.
+
+        This function handles the core attention logic, which varies between
+        prefill and generation modes. In prefill, it computes self-attention
+        over the input sequence with a causal mask. In generation, it attends
+        to the full history of keys and values stored in the cache.
+
+        Args:
+            kv_cache: The key-value cache to be updated and used.
+            q_TNA: Query tensor of shape `(query_seq, num_attention_heads, lkv_dim)`.
+            q_rope_TNH: Query rope tensor of shape `(query_seq, num_attention_heads, rope_dim)`.
+            k_SA: Key tensor of shape `(kv_seq, lkv_dim)`.
+            k_rope_SH: Key rope tensor of shape `(kv_seq, rope_dim)`.
+            attention_metadata: Metadata containing sequence lengths.
+            mesh: The JAX device mesh (unused in this specific function but
+                kept for potential future use or API consistency).
+            q_scale: Quantization scale for q.
+            k_scale: Quantization scale for k.
+            v_scale: Quantization scale for v.
+
+        Returns:
+            A tuple containing:
+            - The updated KV cache.
+            - The attention output tensor of shape
+              `(seq, num_q_heads, head_dim)`.
+        """
+        md = attention_metadata
+        in_specs = (
+            self.query_tnh,  # q
+            self.query_tnh,  # q_rope
+            self.keyvalue_skh,  # k
+            self.keyvalue_skh,  # k_rope
+            P(ShardingAxisName.MLP_TENSOR),  # kv_cache
+            P(ShardingAxisName.ATTN_DATA),  # md.seq_lens: Replicated
+            P(ShardingAxisName.ATTN_DATA),  # page_indices_flat: Replicated
+            P(ShardingAxisName.ATTN_DATA),  # query_start_loc: Replicated
+            P(ShardingAxisName.ATTN_DATA),  # distribution: Replicated
+        )
+
+        out_specs = (self.attn_o_tnh, P(ShardingAxisName.MLP_TENSOR))
+
+        def _mla_ragged_paged_attention(q, q_rope, k, k_rope, kv_cache, *args):
+
+            def _initialize_block_sizes():
+                # Set reasonable starting estimates for block sizes. (TODO(gpolovets): update this to use tuned sizes)
+                # Referring to get_tuned_block_sizes() in kernels/ragged_paged_attention/v3/tuned_block_sizes.py: 'TPU v7'/128/'q_bfloat16_kv_bfloat16/q_head-128_kv_head-1_head-128'/4096
+                max_num_tokens = q.shape[0]
+                max_num_seqs = md.seq_lens.shape[0]
+                num_page_indices = md.block_tables.shape[0]
+                assert num_page_indices % max_num_seqs == 0
+                pages_per_seq = num_page_indices // max_num_seqs
+                # num_kv_pages_per_block = min(pages_per_seq, 16)
+                bkv_p, bq_sz = get_tuned_block_sizes(
+                    q.dtype,
+                    kv_cache.dtype,
+                    self.num_attention_heads,
+                    1,
+                    self.qk_nope_head_dim,
+                    kv_cache.shape[1],  # page size
+                    max_num_tokens,
+                    pages_per_seq,
+                )
+                num_kv_pages_per_block = min(min(pages_per_seq, bkv_p), 4)
+                num_queries_per_block = min(min(max_num_tokens, bq_sz),
+                                            4)  # OOMS at 8
+                return num_kv_pages_per_block, num_queries_per_block
+
+            num_kv_pages_per_block, num_queries_per_block = _initialize_block_sizes(
+            )
+            output, kv_cache = mla_ragged_paged_attention(
+                q,
+                q_rope,
+                k,
+                k_rope,
+                kv_cache,
+                *args,
+                sm_scale=self.scale,
+                num_kv_pages_per_block=num_kv_pages_per_block,
+                num_queries_per_block=num_queries_per_block)
+
+            return kv_cache, output
+
+        kv_cache, output_TNH = jax.jit(
+            jax.shard_map(
+                _mla_ragged_paged_attention,
+                mesh=mesh,
+                in_specs=in_specs,
+                out_specs=out_specs,
+                check_vma=False,
+            ), )(
+                q_TNA,
+                q_rope_TNH,
+                k_SA,
+                k_rope_SH,
+                kv_cache,
+                md.seq_lens,
+                md.block_tables,
+                md.query_start_loc,
+                md.request_distribution,
+            )
+        return kv_cache, output_TNH
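The `use_mla_kernel` branch above leans on the MLA weight-absorption identity: attention scores are bilinear in q and k, so the k up-projection can be folded into the query once per step, letting the kernel score queries directly against the compressed latent cache (`kv_SA`) and defer the v up-projection until after attention. A minimal self-contained sketch of that equivalence, with toy shapes and the file's T/S/N/A/H dimension letters (illustrative code, not part of the package):

import jax
import jax.numpy as jnp

T, S, N, A, H = 4, 6, 2, 8, 3  # query tokens, kv positions, heads, latent rank, nope head dim
k0, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 3)
q_TNH = jax.random.normal(k0, (T, N, H))    # nope part of the query
w_k_ANH = jax.random.normal(k1, (A, N, H))  # k up-projection (kernel_k_up_proj_ANH)
kv_SA = jax.random.normal(k2, (S, A))       # latent KV-cache entries

# Explicit path: materialize full per-head keys, then score q against k.
k_SNH = jnp.einsum("SA,ANH -> SNH", kv_SA, w_k_ANH)
scores_explicit = jnp.einsum("TNH,SNH -> TNS", q_TNH, k_SNH)

# Absorbed path (what the kernel branch uses): fold W_k into q once,
# then score directly against the small latent cache.
q_TNA = jnp.einsum("TNH,ANH -> TNA", q_TNH, w_k_ANH)
scores_absorbed = jnp.einsum("TNA,SA -> TNS", q_TNA, kv_SA)

assert jnp.allclose(scores_explicit, scores_absorbed, atol=1e-5)

Absorption replaces materializing (S, N, H) keys for every cached position with a single (T, N, A) projection of the query, which is what makes caching only the A-dimensional latents worthwhile.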
tpu_inference/layers/jax/attention/gpt_oss_attention.py
CHANGED

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from dataclasses import InitVar, dataclass
 from typing import Tuple
 
@@ -5,7 +19,6 @@ import jax
 import jax.numpy as jnp
 from flax import nnx
 from flax.typing import Sharding
-from jax.experimental import shard_map
 from jax.sharding import Mesh
 from jax.sharding import PartitionSpec as P
 
@@ -13,6 +26,7 @@ from tpu_inference import utils
 from tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 import \
     ragged_paged_attention_hd64
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.quantization import quantize_kv
 from tpu_inference.layers.jax.base import create_param
 from tpu_inference.layers.jax.rope import GptOssRotaryEmbedding
 
@@ -158,17 +172,17 @@ class GptOssAttention(nnx.Module):
     ) -> Tuple[KVCache, jax.Array]:
         """Performs scaled dot-product attention by calling the ragged_paged_attention kernel."""
         md = attention_metadata
-        kv_cache_spec = P(
+        kv_cache_spec = P("data", None, "model")
 
         in_specs = (
             self.query_tnh,  # q
             self.keyvalue_skh,  # k
             self.keyvalue_skh,  # v
             kv_cache_spec,  # kv_cache
-            P(),  # md.seq_lens
-            P(),  # page_indices_flat
-            P(),  # query_start_loc
-            P(),  # distribution
+            P("data"),  # md.seq_lens
+            P("data"),  # page_indices_flat
+            P("data"),  # query_start_loc
+            P("data"),  # distribution
             P(('model')),  # sinks
         )
         out_specs = (self.attn_o_tnh, kv_cache_spec)
@@ -185,12 +199,12 @@ class GptOssAttention(nnx.Module):
         )
 
         output_TNH, kv_cache = jax.jit(
-
+            jax.shard_map(
                _ragged_paged_attention_wrapper,
                mesh=mesh,
                in_specs=in_specs,
                out_specs=out_specs,
-
+                check_vma=False,
            ))(
                q_TNH,
                k_SKH,
@@ -235,9 +249,8 @@ class GptOssAttention(nnx.Module):
             # q_scale = self._q_scale
             k_scale = self._k_scale
             v_scale = self._v_scale
-            k_TKH, v_TKH =
-
-                k_scale, v_scale)
+            k_TKH, v_TKH = quantize_kv(self.kv_cache_quantized_dtype, k_TKH,
+                                       v_TKH, k_scale, v_scale)
 
         with jax.named_scope("attn_op"):
             new_kv_cache, attn_out_TNH = self.attention(
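Two mechanical changes recur across these attention files: the shard-mapped kernel wrappers now go through the public `jax.shard_map` entry point instead of `jax.experimental.shard_map`, and they pass `check_vma=False`. The removed flag at the same position did not survive extraction here; in recent JAX releases `check_vma` supersedes the older `check_rep` argument, so that is the likely predecessor. A minimal self-contained sketch of the pattern under that assumption (not code from the package):

import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P

mesh = Mesh(jax.devices(), ("model", ))

def per_shard_double(x):
    # Runs on each "model" shard's local block of x.
    return x * 2.0

doubled = jax.jit(
    jax.shard_map(
        per_shard_double,
        mesh=mesh,
        in_specs=P("model"),
        out_specs=P("model"),
        check_vma=False,  # assumed successor of the old check_rep flag
    ))(jnp.arange(4.0 * len(jax.devices())))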
tpu_inference/layers/jax/attention/llama4_attention.py
CHANGED

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from dataclasses import dataclass
 
 import jax
@@ -5,8 +19,8 @@ import jax.numpy as jnp
 from flax import nnx
 from jax.sharding import Sharding
 
-from tpu_inference import utils
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.quantization import quantize_kv
 from tpu_inference.layers.jax.attention.attention import Attention, KVCache
 from tpu_inference.layers.jax.rope_interface import apply_rope
 from tpu_inference.logger import init_logger
@@ -114,9 +128,8 @@ class Llama4Attention(Attention):
             # q_scale = self._q_scale
             k_scale = self._k_scale
             v_scale = self._v_scale
-            k_SKH, v_SKH =
-
-                k_scale, v_scale)
+            k_SKH, v_SKH = quantize_kv(self.kv_cache_quantized_dtype, k_SKH,
+                                       v_SKH, k_scale, v_scale)
 
         with jax.named_scope("attn_op"):
             new_kv_cache, outputs_TNH = self.attention(
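All three attention files also migrate from `utils.quantize_kv(k, v, dtype, k_scale, v_scale)` to the new `quantize_kv(dtype, k, v, k_scale, v_scale)` helper in `tpu_inference.layers.common.quantization`, a module added in this release (+270 lines) whose implementation is not shown in this diff. A hypothetical sketch of what such a helper plausibly does, matching only the new call sites; the body below is illustrative, not the package's code:

import jax.numpy as jnp

def quantize_kv(quantized_dtype, k, v, k_scale=None, v_scale=None):
    """Scale k/v and cast them to the KV-cache storage dtype (hypothetical)."""

    def _quantize(x, scale):
        if scale is not None:
            x = x / scale  # per-tensor scale keeps values inside the target range
        if jnp.issubdtype(quantized_dtype, jnp.integer):
            info = jnp.iinfo(quantized_dtype)
            x = jnp.round(x)
        else:
            info = jnp.finfo(quantized_dtype)
        return jnp.clip(x, info.min, info.max).astype(quantized_dtype)

    return _quantize(k, k_scale), _quantize(v, v_scale)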
tpu_inference/layers/jax/base.py
CHANGED

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import dataclasses
 from dataclasses import dataclass, fields
 from typing import Any, Callable, Mapping
tpu_inference/layers/jax/constants.py
CHANGED

@@ -1,3 +1,16 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """
 Current Used Abbreviation for Tensor Dimensions:
 B: Batch size
tpu_inference/layers/jax/layers.py
CHANGED

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from dataclasses import InitVar, dataclass
 from typing import Any
 
tpu_inference/layers/jax/misc.py
CHANGED

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import math
 from typing import Tuple
 
tpu_inference/layers/jax/moe/__init__.py
ADDED

@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.