tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +78 -1
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +38 -7
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +17 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +28 -5
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +74 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +88 -25
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -64
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +72 -37
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +45 -15
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +41 -16
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +42 -36
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +63 -50
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
- tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
|
@@ -1,3 +1,16 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
1
14
|
"""TPU-Friendly and Data-Movement-Friendly MLA Ragged Paged Attention kernel."""
|
|
2
15
|
|
|
3
16
|
import functools
|
|
@@ -16,17 +29,30 @@ DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)
|
|
|
16
29
|
DEFAULT_VMEM_LIMIT_BYTES = 100 * 1024 * 1024
|
|
17
30
|
|
|
18
31
|
|
|
32
|
+
def get_kv_cache_shape(
|
|
33
|
+
total_num_pages,
|
|
34
|
+
page_size,
|
|
35
|
+
kv_dim,
|
|
36
|
+
kv_dtype,
|
|
37
|
+
):
|
|
38
|
+
kv_packing = get_dtype_packing(kv_dtype)
|
|
39
|
+
return (
|
|
40
|
+
total_num_pages,
|
|
41
|
+
align_to(page_size, kv_packing) // kv_packing,
|
|
42
|
+
kv_packing,
|
|
43
|
+
align_to(kv_dim, 128),
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
|
|
19
47
|
@functools.partial(
|
|
20
48
|
jax.jit,
|
|
21
|
-
donate_argnames=("
|
|
49
|
+
donate_argnames=("cache_kv"),
|
|
22
50
|
)
|
|
23
51
|
def update_kv_cache(
|
|
24
52
|
new_kv_c: jax.Array, # [num_tokens, actual_lkv_dim]
|
|
25
53
|
new_k_pe: jax.Array, # [num_tokens, actual_r_dim]
|
|
26
|
-
|
|
27
|
-
Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim]
|
|
28
|
-
cache_k_pe: jax.
|
|
29
|
-
Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, r_dim]
|
|
54
|
+
cache_kv: jax.
|
|
55
|
+
Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim+r_dim]
|
|
30
56
|
kv_lens: jax.Array, # i32[max_num_seqs]
|
|
31
57
|
page_indices: jax.Array, # i32[max_num_seqs * pages_per_seq]
|
|
32
58
|
cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
|
|
@@ -43,25 +69,21 @@ def update_kv_cache(
|
|
|
43
69
|
if actual_lkv_dim != lkv_dim:
|
|
44
70
|
new_kv_c = jnp.pad(new_kv_c, ((0, 0), (0, lkv_dim - actual_lkv_dim)),
|
|
45
71
|
constant_values=0)
|
|
46
|
-
|
|
47
|
-
_, page_size_per_kv_packing, kv_packing,
|
|
48
|
-
|
|
49
|
-
assert lkv_dim == cache_lkv_dim
|
|
50
|
-
assert r_dim == cache_r_dim
|
|
72
|
+
kv_dim = r_dim + lkv_dim
|
|
73
|
+
_, page_size_per_kv_packing, kv_packing, cache_kv_dim = cache_kv.shape
|
|
74
|
+
assert kv_dim == cache_kv_dim
|
|
51
75
|
page_size = page_size_per_kv_packing * kv_packing
|
|
52
76
|
|
|
53
77
|
max_num_seqs = kv_lens.shape[0]
|
|
54
78
|
num_page_indices = page_indices.shape[0]
|
|
55
79
|
pages_per_seq = num_page_indices // max_num_seqs
|
|
56
80
|
|
|
57
|
-
def seq_loop_body(i,
|
|
58
|
-
cache_kv_c, cache_k_pe = caches
|
|
81
|
+
def seq_loop_body(i, cache_kv):
|
|
59
82
|
q_start, q_end = cu_q_lens[i], cu_q_lens[i + 1]
|
|
60
83
|
q_len = q_end - q_start
|
|
61
84
|
kv_len = kv_lens[i]
|
|
62
85
|
|
|
63
|
-
def token_loop_body(j,
|
|
64
|
-
cache_kv_c_, cache_k_pe_ = caches_
|
|
86
|
+
def token_loop_body(j, cache_kv_):
|
|
65
87
|
token_idx_in_seq = kv_len - q_len + j
|
|
66
88
|
page_num_in_seq = token_idx_in_seq // page_size
|
|
67
89
|
page_indices_start = i * pages_per_seq
|
|
@@ -69,18 +91,17 @@ def update_kv_cache(
|
|
|
69
91
|
row = (token_idx_in_seq % page_size) // kv_packing
|
|
70
92
|
col = (token_idx_in_seq % page_size) % kv_packing
|
|
71
93
|
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
return
|
|
94
|
+
cache_kv_ = cache_kv_.at[page_idx, row, col,
|
|
95
|
+
..., :lkv_dim].set(new_kv_c[q_start + j])
|
|
96
|
+
cache_kv_ = cache_kv_.at[page_idx, row, col, ...,
|
|
97
|
+
lkv_dim:].set(new_k_pe[q_start + j])
|
|
98
|
+
return cache_kv_
|
|
99
|
+
|
|
100
|
+
return lax.fori_loop(0, q_len, token_loop_body, cache_kv)
|
|
77
101
|
|
|
78
|
-
|
|
79
|
-
(cache_kv_c, cache_k_pe))
|
|
102
|
+
cache_kv = lax.fori_loop(0, distribution[-1], seq_loop_body, cache_kv)
|
|
80
103
|
|
|
81
|
-
|
|
82
|
-
(cache_kv_c, cache_k_pe))
|
|
83
|
-
return cache_kv_c, cache_k_pe
|
|
104
|
+
return cache_kv
|
|
84
105
|
|
|
85
106
|
|
|
86
107
|
def ref_mla_ragged_paged_attention(
|
|
@@ -88,10 +109,8 @@ def ref_mla_ragged_paged_attention(
|
|
|
88
109
|
q_pe: jax.Array, # [num_tokens, actual_num_q_heads, actual_r_dim]
|
|
89
110
|
new_kv_c: jax.Array, # [num_tokens, actual_lkv_dim]
|
|
90
111
|
new_k_pe: jax.Array, # [num_tokens, actual_r_dim]
|
|
91
|
-
|
|
112
|
+
cache_kv: jax.
|
|
92
113
|
Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim]
|
|
93
|
-
cache_k_pe: jax.
|
|
94
|
-
Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, r_dim]
|
|
95
114
|
kv_lens: jax.Array, # i32[max_num_seqs]
|
|
96
115
|
page_indices: jax.Array, # i32[max_num_seqs * pages_per_seq]
|
|
97
116
|
cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
|
|
@@ -111,8 +130,7 @@ def ref_mla_ragged_paged_attention(
|
|
|
111
130
|
q_pe,
|
|
112
131
|
new_kv_c,
|
|
113
132
|
new_k_pe,
|
|
114
|
-
|
|
115
|
-
cache_k_pe,
|
|
133
|
+
cache_kv,
|
|
116
134
|
kv_lens,
|
|
117
135
|
page_indices,
|
|
118
136
|
cu_q_lens,
|
|
@@ -123,11 +141,10 @@ def ref_mla_ragged_paged_attention(
|
|
|
123
141
|
mask_value=mask_value,
|
|
124
142
|
)
|
|
125
143
|
|
|
126
|
-
|
|
144
|
+
updated_cache_kv = update_kv_cache(
|
|
127
145
|
new_kv_c,
|
|
128
146
|
new_k_pe,
|
|
129
|
-
|
|
130
|
-
cache_k_pe,
|
|
147
|
+
cache_kv,
|
|
131
148
|
kv_lens,
|
|
132
149
|
page_indices,
|
|
133
150
|
cu_q_lens,
|
|
@@ -154,13 +171,17 @@ def ref_mla_ragged_paged_attention(
|
|
|
154
171
|
assert num_page_indices % max_num_seqs == 0
|
|
155
172
|
pages_per_seq = num_page_indices // max_num_seqs
|
|
156
173
|
|
|
157
|
-
total_num_pages, page_size_per_kv_packing, kv_packing, _ =
|
|
174
|
+
total_num_pages, page_size_per_kv_packing, kv_packing, _ = updated_cache_kv.shape
|
|
158
175
|
page_size = page_size_per_kv_packing * kv_packing
|
|
159
176
|
assert lkv_dim == ql_nope.shape[-1]
|
|
160
177
|
assert r_dim == q_pe.shape[-1]
|
|
178
|
+
assert lkv_dim + r_dim == updated_cache_kv.shape[-1]
|
|
161
179
|
|
|
162
|
-
kv_c_cache =
|
|
163
|
-
|
|
180
|
+
kv_c_cache = updated_cache_kv[..., :lkv_dim].reshape(
|
|
181
|
+
total_num_pages, page_size, lkv_dim)
|
|
182
|
+
k_pe_cache = updated_cache_kv[...,
|
|
183
|
+
lkv_dim:].reshape(total_num_pages, page_size,
|
|
184
|
+
r_dim)
|
|
164
185
|
|
|
165
186
|
outputs = []
|
|
166
187
|
|
|
@@ -221,8 +242,7 @@ def ref_mla_ragged_paged_attention(
|
|
|
221
242
|
|
|
222
243
|
return (
|
|
223
244
|
jnp.concatenate(outputs, axis=0),
|
|
224
|
-
|
|
225
|
-
cache_k_pe,
|
|
245
|
+
updated_cache_kv,
|
|
226
246
|
)
|
|
227
247
|
|
|
228
248
|
|
|
@@ -232,10 +252,8 @@ def dynamic_validate_inputs(
|
|
|
232
252
|
q_pe: jax.Array, # [max_num_tokens, actual_num_q_heads, actual_r_dim]
|
|
233
253
|
new_kv_c: jax.Array, # [max_num_tokens, actual_lkv_dim]
|
|
234
254
|
new_k_pe: jax.Array, # [max_num_tokens, actual_r_dim]
|
|
235
|
-
|
|
255
|
+
cache_kv: jax.
|
|
236
256
|
Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim]
|
|
237
|
-
cache_k_pe: jax.
|
|
238
|
-
Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, r_dim]
|
|
239
257
|
kv_lens: jax.Array, # i32[max_num_seqs]
|
|
240
258
|
page_indices: jax.Array, # i32[max_num_seqs * pages_per_seq]
|
|
241
259
|
cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
|
|
@@ -260,8 +278,7 @@ def dynamic_validate_inputs(
|
|
|
260
278
|
q_pe,
|
|
261
279
|
new_kv_c,
|
|
262
280
|
new_k_pe,
|
|
263
|
-
|
|
264
|
-
cache_k_pe,
|
|
281
|
+
cache_kv,
|
|
265
282
|
kv_lens,
|
|
266
283
|
page_indices,
|
|
267
284
|
cu_q_lens,
|
|
@@ -277,8 +294,8 @@ def dynamic_validate_inputs(
|
|
|
277
294
|
debug_mode=debug_mode,
|
|
278
295
|
)
|
|
279
296
|
max_num_tokens = ql_nope.shape[0]
|
|
280
|
-
total_num_pages =
|
|
281
|
-
_, page_size_per_kv_packing, kv_packing, _ =
|
|
297
|
+
total_num_pages = cache_kv.shape[0]
|
|
298
|
+
_, page_size_per_kv_packing, kv_packing, _ = cache_kv.shape
|
|
282
299
|
page_size = page_size_per_kv_packing * kv_packing
|
|
283
300
|
max_num_seqs = kv_lens.shape[0]
|
|
284
301
|
num_page_indices = page_indices.shape[0]
|
|
@@ -320,10 +337,8 @@ def static_validate_inputs(
|
|
|
320
337
|
q_pe: jax.Array, # [max_num_tokens, actual_num_q_heads, actual_r_dim]
|
|
321
338
|
new_kv_c: jax.Array, # [max_num_tokens, actual_lkv_dim]
|
|
322
339
|
new_k_pe: jax.Array, # [max_num_tokens, actual_r_dim]
|
|
323
|
-
|
|
340
|
+
cache_kv: jax.
|
|
324
341
|
Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim]
|
|
325
|
-
cache_k_pe: jax.
|
|
326
|
-
Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, r_dim]
|
|
327
342
|
kv_lens: jax.Array, # i32[max_num_seqs]
|
|
328
343
|
page_indices: jax.Array, # i32[max_num_seqs * pages_per_seq]
|
|
329
344
|
cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
|
|
@@ -373,44 +388,34 @@ def static_validate_inputs(
|
|
|
373
388
|
|
|
374
389
|
actual_lkv_dim = ql_nope.shape[2]
|
|
375
390
|
actual_r_dim = q_pe.shape[2]
|
|
391
|
+
lkv_dim = align_to(actual_lkv_dim, 128)
|
|
392
|
+
r_dim = align_to(actual_r_dim, 128)
|
|
376
393
|
|
|
377
394
|
(
|
|
378
395
|
_,
|
|
379
396
|
page_size_per_kv_packing,
|
|
380
397
|
kv_packing,
|
|
381
|
-
|
|
382
|
-
) =
|
|
383
|
-
_, _, _, r_dim = cache_k_pe.shape
|
|
398
|
+
kv_dim,
|
|
399
|
+
) = cache_kv.shape
|
|
384
400
|
|
|
385
|
-
if lkv_dim !=
|
|
386
|
-
raise ValueError(
|
|
387
|
-
f"Expected {lkv_dim=} is equal to {align_to(actual_lkv_dim, 128)=}"
|
|
388
|
-
)
|
|
389
|
-
if r_dim != align_to(actual_r_dim, 128):
|
|
401
|
+
if lkv_dim + r_dim != kv_dim:
|
|
390
402
|
raise ValueError(
|
|
391
|
-
f"Expected {r_dim=}
|
|
403
|
+
f"Expected {lkv_dim=} + {r_dim=} to be equal to {kv_dim=}")
|
|
392
404
|
|
|
393
|
-
if not (
|
|
405
|
+
if not (cache_kv.dtype == new_kv_c.dtype):
|
|
394
406
|
raise ValueError(
|
|
395
|
-
f"Expected {
|
|
396
|
-
if not (
|
|
407
|
+
f"Expected {cache_kv.dtype=} to be equal to {new_kv_c.dtype=}.")
|
|
408
|
+
if not (cache_kv.dtype == new_k_pe.dtype):
|
|
397
409
|
raise ValueError(
|
|
398
|
-
f"Expected {
|
|
410
|
+
f"Expected {cache_kv.dtype=} to be equal to {new_k_pe.dtype=}.")
|
|
399
411
|
|
|
400
412
|
# Integer kv quantization is currently not supported.
|
|
401
|
-
if not jnp.issubdtype(
|
|
402
|
-
raise ValueError(
|
|
403
|
-
f"Expected {cache_kv_c.dtype=} to be a floating point.")
|
|
404
|
-
if not jnp.issubdtype(cache_k_pe.dtype, jnp.floating):
|
|
405
|
-
raise ValueError(
|
|
406
|
-
f"Expected {cache_k_pe.dtype=} to be a floating point.")
|
|
413
|
+
if not jnp.issubdtype(cache_kv.dtype, jnp.floating):
|
|
414
|
+
raise ValueError(f"Expected {cache_kv.dtype=} to be a floating point.")
|
|
407
415
|
|
|
408
|
-
if kv_packing != get_dtype_packing(
|
|
416
|
+
if kv_packing != get_dtype_packing(cache_kv.dtype):
|
|
409
417
|
raise ValueError(
|
|
410
|
-
f"{kv_packing=} does not match with {
|
|
411
|
-
if kv_packing != get_dtype_packing(cache_k_pe.dtype):
|
|
412
|
-
raise ValueError(
|
|
413
|
-
f"{kv_packing=} does not match with {cache_k_pe.dtype=}")
|
|
418
|
+
f"{kv_packing=} does not match with {cache_kv.dtype=}")
|
|
414
419
|
|
|
415
420
|
if not (jnp.int32 == kv_lens.dtype == page_indices.dtype == cu_q_lens.dtype
|
|
416
421
|
== distribution.dtype):
|
|
@@ -475,14 +480,12 @@ def _mla_ragged_paged_attention_kernel(
|
|
|
475
480
|
q_pe_hbm_ref, # [max_num_tokens, num_q_heads_per_q_packing, q_packing, r_dim]
|
|
476
481
|
new_kv_c_hbm_ref, # [max_num_tokens_per_kv_packing, kv_packing, lkv_dim]
|
|
477
482
|
new_k_pe_hbm_ref, # [max_num_tokens_per_kv_packing, kv_packing, r_dim]
|
|
478
|
-
|
|
479
|
-
cache_k_pe_hbm_ref, # [total_num_pages, page_size_per_kv_packing, kv_packing, r_dim]
|
|
483
|
+
cache_kv_hbm_ref, # [total_num_pages, page_size_per_kv_packing, kv_packing, align_to(lkv_dim + r_dim, 128)]
|
|
480
484
|
# Output
|
|
481
485
|
o_hbm_ref, # [max_num_tokens, num_q_heads_per_q_packing, q_packing, lkv_dim]
|
|
482
|
-
|
|
483
|
-
updated_cache_k_pe_hbm_ref, # [total_num_pages, page_size_per_kv_packing, kv_packing, r_dim]
|
|
486
|
+
updated_cache_kv_hbm_ref, # [total_num_pages, page_size_per_kv_packing, kv_packing, align_to(lkv_dim + r_dim, 128)]
|
|
484
487
|
# Scratch
|
|
485
|
-
bkvc_x2_ref, # [2, bkv_sz_per_kv_packing, kv_packing, lkv_dim]
|
|
488
|
+
bkvc_x2_ref, # [2, bkv_sz_per_kv_packing, kv_packing, lkv_dim].
|
|
486
489
|
bkpe_x2_ref, # [2, bkv_sz_per_kv_packing, kv_packing, r_dim]
|
|
487
490
|
bq_nope_x2_ref, # [2, bq_sz, num_q_heads_per_q_packing, q_packing, lkv_dim]
|
|
488
491
|
bq_rope_x2_ref, # [2, bq_sz, num_q_heads_per_q_packing, q_packing, r_dim]
|
|
@@ -505,20 +508,24 @@ def _mla_ragged_paged_attention_kernel(
|
|
|
505
508
|
debug_mode: bool = False,
|
|
506
509
|
):
|
|
507
510
|
assert ql_nope_hbm_ref.shape == o_hbm_ref.shape
|
|
508
|
-
|
|
509
|
-
|
|
511
|
+
# Validation checks on the dimensions
|
|
512
|
+
nope_dim = ql_nope_hbm_ref.shape[-1]
|
|
513
|
+
pe_dim = q_pe_hbm_ref.shape[-1]
|
|
514
|
+
assert nope_dim + pe_dim == cache_kv_hbm_ref.shape[-1]
|
|
515
|
+
|
|
510
516
|
_, num_q_heads_per_q_packing, q_packing, lkv_dim = ql_nope_hbm_ref.shape
|
|
511
517
|
r_dim = q_pe_hbm_ref.shape[-1]
|
|
512
518
|
num_q_heads = num_q_heads_per_q_packing * q_packing
|
|
513
519
|
total_num_pages, page_size_per_kv_packing, kv_packing, _ = (
|
|
514
|
-
|
|
520
|
+
cache_kv_hbm_ref.shape)
|
|
515
521
|
max_num_seqs = kv_lens_ref.shape[0]
|
|
516
522
|
num_page_indices = page_indices_ref.shape[0]
|
|
517
523
|
|
|
518
524
|
assert num_page_indices % max_num_seqs == 0
|
|
519
525
|
pages_per_seq = num_page_indices // max_num_seqs
|
|
520
526
|
q_dtype = ql_nope_hbm_ref.dtype
|
|
521
|
-
|
|
527
|
+
# Validate against the KV dtype.
|
|
528
|
+
kv_dtype = cache_kv_hbm_ref.dtype
|
|
522
529
|
assert q_pe_hbm_ref.dtype == q_dtype
|
|
523
530
|
assert o_hbm_ref.dtype == q_dtype
|
|
524
531
|
assert get_dtype_packing(q_dtype) == q_packing
|
|
@@ -561,8 +568,8 @@ def _mla_ragged_paged_attention_kernel(
|
|
|
561
568
|
def flash_attention(
|
|
562
569
|
ql_nope, # [actual_bq_sz * num_q_heads, lkv_dim]
|
|
563
570
|
q_pe, # [actual_bq_sz * num_q_heads, r_dim]
|
|
564
|
-
kv_c, # [bkv_sz, lkv_dim]
|
|
565
|
-
k_pe, # [bkv_sz, r_dim]
|
|
571
|
+
kv_c, # [bkv_sz, lkv_dim] <- Correspond to data from bkvc_x2_ref
|
|
572
|
+
k_pe, # [bkv_sz, r_dim] <- Correspond to data from bpe_x2_ref
|
|
566
573
|
*,
|
|
567
574
|
bq_idx,
|
|
568
575
|
bkv_idx,
|
|
@@ -649,14 +656,9 @@ def _mla_ragged_paged_attention_kernel(
|
|
|
649
656
|
sem = sems.at[0, bkv_sem_idx]
|
|
650
657
|
bkvc_vmem_ref = bkvc_x2_ref.at[bkv_sem_idx]
|
|
651
658
|
bkvpe_vmem_ref = bkpe_x2_ref.at[bkv_sem_idx]
|
|
652
|
-
|
|
653
|
-
reshaped_cache_kv_c_hbm_ref = cache_kv_c_hbm_ref.reshape(
|
|
659
|
+
reshaped_cache_hbm_ref = cache_kv_hbm_ref.reshape(
|
|
654
660
|
total_num_pages * page_size_per_kv_packing,
|
|
655
|
-
*
|
|
656
|
-
)
|
|
657
|
-
reshaped_cache_k_pe_hbm_ref = cache_k_pe_hbm_ref.reshape(
|
|
658
|
-
total_num_pages * page_size_per_kv_packing,
|
|
659
|
-
*cache_k_pe_hbm_ref.shape[2:],
|
|
661
|
+
*cache_kv_hbm_ref.shape[2:],
|
|
660
662
|
)
|
|
661
663
|
kv_len = kv_lens_ref[seq_idx]
|
|
662
664
|
kv_len_start = bkv_idx * bkv_sz
|
|
@@ -684,22 +686,22 @@ def _mla_ragged_paged_attention_kernel(
|
|
|
684
686
|
kv_left_per_kv_packing - i * page_size_per_kv_packing,
|
|
685
687
|
)
|
|
686
688
|
_async_copy(
|
|
687
|
-
|
|
689
|
+
reshaped_cache_hbm_ref.at[pl.ds(
|
|
688
690
|
page_indices_ref[page_indices_offset + i] *
|
|
689
691
|
page_size_per_kv_packing,
|
|
690
692
|
sz_per_kv_packing,
|
|
691
|
-
)],
|
|
693
|
+
), ..., :nope_dim],
|
|
692
694
|
bkvc_vmem_ref.at[pl.ds(i * page_size_per_kv_packing,
|
|
693
695
|
sz_per_kv_packing)],
|
|
694
696
|
sem,
|
|
695
697
|
wait,
|
|
696
698
|
)
|
|
697
699
|
_async_copy(
|
|
698
|
-
|
|
700
|
+
reshaped_cache_hbm_ref.at[pl.ds(
|
|
699
701
|
page_indices_ref[page_indices_offset + i] *
|
|
700
702
|
page_size_per_kv_packing,
|
|
701
703
|
sz_per_kv_packing,
|
|
702
|
-
)],
|
|
704
|
+
), ..., nope_dim:],
|
|
703
705
|
bkvpe_vmem_ref.at[pl.ds(i * page_size_per_kv_packing,
|
|
704
706
|
sz_per_kv_packing)],
|
|
705
707
|
sem,
|
|
@@ -820,37 +822,17 @@ def _mla_ragged_paged_attention_kernel(
|
|
|
820
822
|
return q_nope_vec, q_rope_vec
|
|
821
823
|
|
|
822
824
|
def load_bkv(bkv_sem_idx, *, bkvc_mask, bkpe_mask):
|
|
823
|
-
bitwidth = 32 // kv_packing
|
|
824
|
-
repack_ty = jnp.dtype(f"uint{bitwidth}")
|
|
825
825
|
bkvc_ref = (bkvc_x2_ref.bitcast(jnp.uint32).at[bkv_sem_idx].reshape(
|
|
826
826
|
bkv_sz_per_kv_packing, lkv_dim))
|
|
827
|
-
bkvc_vec = bkvc_ref[...]
|
|
828
|
-
|
|
829
|
-
for i in range(kv_packing):
|
|
830
|
-
masked_bkvc_vec = bkvc_vec >> (i * bitwidth)
|
|
831
|
-
bkvc_vecs.append(masked_bkvc_vec)
|
|
832
|
-
concated_bkvc_vec = jnp.concatenate(bkvc_vecs, axis=-1)
|
|
833
|
-
concated_bkvc_vec = concated_bkvc_vec.reshape(bkv_sz, lkv_dim)
|
|
834
|
-
concated_bkvc_vec = lax.select(bkvc_mask, concated_bkvc_vec,
|
|
835
|
-
jnp.zeros_like(concated_bkvc_vec))
|
|
836
|
-
concated_bkvc_vec = pltpu.bitcast(concated_bkvc_vec.astype(repack_ty),
|
|
837
|
-
kv_dtype)
|
|
827
|
+
bkvc_vec = pltpu.bitcast(bkvc_ref[...], kv_dtype)
|
|
828
|
+
bkvc_vec = lax.select(bkvc_mask, bkvc_vec, jnp.zeros_like(bkvc_vec))
|
|
838
829
|
|
|
839
830
|
bkpe_ref = (bkpe_x2_ref.bitcast(jnp.uint32).at[bkv_sem_idx].reshape(
|
|
840
831
|
bkv_sz_per_kv_packing, r_dim))
|
|
841
|
-
bkpe_vec = bkpe_ref[...]
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
bkpe_vecs.append(masked_bkpe_vec)
|
|
846
|
-
concated_bkpe_vec = jnp.concatenate(bkpe_vecs, axis=-1)
|
|
847
|
-
concated_bkpe_vec = concated_bkpe_vec.reshape(bkv_sz, r_dim)
|
|
848
|
-
concated_bkpe_vec = lax.select(bkpe_mask, concated_bkpe_vec,
|
|
849
|
-
jnp.zeros_like(concated_bkpe_vec))
|
|
850
|
-
concated_bkpe_vec = pltpu.bitcast(concated_bkpe_vec.astype(repack_ty),
|
|
851
|
-
kv_dtype)
|
|
852
|
-
|
|
853
|
-
return concated_bkvc_vec, concated_bkpe_vec
|
|
832
|
+
bkpe_vec = pltpu.bitcast(bkpe_ref[...], kv_dtype)
|
|
833
|
+
bkpe_vec = lax.select(bkpe_mask, bkpe_vec, jnp.zeros_like(bkpe_vec))
|
|
834
|
+
|
|
835
|
+
return bkvc_vec, bkpe_vec
|
|
854
836
|
|
|
855
837
|
def broadcast_minor(src, shape):
|
|
856
838
|
if src.shape == shape:
|
|
@@ -1082,17 +1064,16 @@ def prepare_outputs(
|
|
|
1082
1064
|
"vmem_limit_bytes",
|
|
1083
1065
|
"debug_mode",
|
|
1084
1066
|
),
|
|
1085
|
-
donate_argnames=("
|
|
1067
|
+
donate_argnames=("cache_kv"),
|
|
1086
1068
|
)
|
|
1087
1069
|
def mla_ragged_paged_attention(
|
|
1088
1070
|
ql_nope: jax.Array, # [max_num_tokens, actual_num_q_heads, actual_lkv_dim]
|
|
1089
1071
|
q_pe: jax.Array, # [max_num_tokens, actual_num_q_heads, actual_r_dim]
|
|
1090
1072
|
new_kv_c: jax.Array, # [max_num_tokens, actual_lkv_dim]
|
|
1091
1073
|
new_k_pe: jax.Array, # [max_num_tokens, actual_r_dim]
|
|
1092
|
-
|
|
1093
|
-
|
|
1094
|
-
|
|
1095
|
-
Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, r_dim]
|
|
1074
|
+
# TODO(gpolovets): Explore separating out into lkv & pe KV caches.
|
|
1075
|
+
cache_kv: jax.
|
|
1076
|
+
Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, align_to(lkv_dim, 128)]
|
|
1096
1077
|
kv_lens: jax.Array, # i32[max_num_seqs]
|
|
1097
1078
|
page_indices: jax.Array, # i32[max_num_seqs * pages_per_seq]
|
|
1098
1079
|
cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
|
|
@@ -1124,8 +1105,7 @@ def mla_ragged_paged_attention(
|
|
|
1124
1105
|
q_pe: concatenated all sequences' rope.
|
|
1125
1106
|
new_kv_c: concatenated all sequences' kv_c values
|
|
1126
1107
|
new_k_pe: concatenated all sequences' k_pe values
|
|
1127
|
-
|
|
1128
|
-
cache_k_pe: the current k_pe cache.
|
|
1108
|
+
cache_kv: the current kv cache.
|
|
1129
1109
|
kv_lens: the length of each sequence in the kv cache.
|
|
1130
1110
|
page_indices: flattened page indices look-up table by (seq_id, page_id).
|
|
1131
1111
|
cu_q_lens: the cumulative sum of the effective query lengths. Similar to
|
|
@@ -1159,8 +1139,7 @@ def mla_ragged_paged_attention(
|
|
|
1159
1139
|
q_pe,
|
|
1160
1140
|
new_kv_c,
|
|
1161
1141
|
new_k_pe,
|
|
1162
|
-
|
|
1163
|
-
cache_k_pe,
|
|
1142
|
+
cache_kv,
|
|
1164
1143
|
kv_lens,
|
|
1165
1144
|
page_indices,
|
|
1166
1145
|
cu_q_lens,
|
|
@@ -1177,11 +1156,10 @@ def mla_ragged_paged_attention(
|
|
|
1177
1156
|
)
|
|
1178
1157
|
|
|
1179
1158
|
# TODO(chengjiyao): fuse kv cache update into the kernel.
|
|
1180
|
-
|
|
1159
|
+
cache_kv = update_kv_cache(
|
|
1181
1160
|
new_kv_c,
|
|
1182
1161
|
new_k_pe,
|
|
1183
|
-
|
|
1184
|
-
cache_k_pe,
|
|
1162
|
+
cache_kv,
|
|
1185
1163
|
kv_lens,
|
|
1186
1164
|
page_indices,
|
|
1187
1165
|
cu_q_lens,
|
|
@@ -1202,7 +1180,7 @@ def mla_ragged_paged_attention(
|
|
|
1202
1180
|
lkv_dim = new_kv_c.shape[-1]
|
|
1203
1181
|
r_dim = new_k_pe.shape[-1]
|
|
1204
1182
|
|
|
1205
|
-
_, page_size_per_kv_packing, kv_packing, _ =
|
|
1183
|
+
_, page_size_per_kv_packing, kv_packing, _ = cache_kv.shape
|
|
1206
1184
|
page_size = page_size_per_kv_packing * kv_packing
|
|
1207
1185
|
_, num_q_heads_per_q_packing, q_packing, _ = ql_nope.shape
|
|
1208
1186
|
max_num_seqs = kv_lens.shape[0]
|
|
@@ -1221,23 +1199,21 @@ def mla_ragged_paged_attention(
|
|
|
1221
1199
|
pl.BlockSpec(memory_space=pltpu.HBM),
|
|
1222
1200
|
pl.BlockSpec(memory_space=pltpu.HBM),
|
|
1223
1201
|
pl.BlockSpec(memory_space=pltpu.HBM),
|
|
1224
|
-
pl.BlockSpec(memory_space=pltpu.HBM),
|
|
1225
1202
|
]
|
|
1226
1203
|
|
|
1227
1204
|
out_specs = [
|
|
1228
1205
|
pl.BlockSpec(memory_space=pltpu.HBM),
|
|
1229
1206
|
pl.BlockSpec(memory_space=pltpu.HBM),
|
|
1230
|
-
pl.BlockSpec(memory_space=pltpu.HBM),
|
|
1231
1207
|
]
|
|
1232
1208
|
|
|
1233
1209
|
bkvc_double_buf = pltpu.VMEM(
|
|
1234
1210
|
(2, bkv_sz_per_kv_packing, kv_packing, lkv_dim),
|
|
1235
|
-
|
|
1211
|
+
cache_kv.dtype,
|
|
1236
1212
|
)
|
|
1237
1213
|
|
|
1238
1214
|
bkpe_double_buf = pltpu.VMEM(
|
|
1239
1215
|
(2, bkv_sz_per_kv_packing, kv_packing, r_dim),
|
|
1240
|
-
|
|
1216
|
+
cache_kv.dtype,
|
|
1241
1217
|
)
|
|
1242
1218
|
|
|
1243
1219
|
bq_nope_double_buf = pltpu.VMEM(
|
|
@@ -1320,30 +1296,26 @@ def mla_ragged_paged_attention(
|
|
|
1320
1296
|
),
|
|
1321
1297
|
out_shape=[
|
|
1322
1298
|
jax.ShapeDtypeStruct(shape=ql_nope.shape, dtype=ql_nope.dtype),
|
|
1323
|
-
jax.ShapeDtypeStruct(shape=
|
|
1324
|
-
dtype=
|
|
1325
|
-
jax.ShapeDtypeStruct(shape=cache_k_pe.shape,
|
|
1326
|
-
dtype=cache_k_pe.dtype),
|
|
1299
|
+
jax.ShapeDtypeStruct(shape=cache_kv.shape,
|
|
1300
|
+
dtype=cache_kv.dtype),
|
|
1327
1301
|
],
|
|
1328
1302
|
input_output_aliases={
|
|
1329
1303
|
7: 0,
|
|
1330
1304
|
11: 1,
|
|
1331
|
-
12: 2
|
|
1332
1305
|
},
|
|
1333
1306
|
name=scope_name,
|
|
1334
1307
|
))
|
|
1335
1308
|
|
|
1336
|
-
output,
|
|
1309
|
+
output, updated_kv = kernel(
|
|
1337
1310
|
*scalar_prefetches,
|
|
1338
1311
|
ql_nope,
|
|
1339
1312
|
q_pe,
|
|
1340
1313
|
new_kv_c,
|
|
1341
1314
|
new_k_pe,
|
|
1342
|
-
|
|
1343
|
-
cache_k_pe,
|
|
1315
|
+
cache_kv,
|
|
1344
1316
|
)
|
|
1345
1317
|
output = prepare_outputs(
|
|
1346
1318
|
output, actual_num_q_heads,
|
|
1347
1319
|
actual_lkv_dim) # [max_num_tokens, actual_num_q_heads, actual_lkv_dim]
|
|
1348
1320
|
|
|
1349
|
-
return output,
|
|
1321
|
+
return output, updated_kv
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -9,12 +9,58 @@ from jax._src import dtypes
|
|
|
9
9
|
from jax.experimental import pallas as pl
|
|
10
10
|
from jax.experimental.pallas import tpu as pltpu
|
|
11
11
|
|
|
12
|
+
from tpu_inference.kernels.quantized_matmul import util
|
|
12
13
|
from tpu_inference.kernels.quantized_matmul.tuned_block_sizes import (
|
|
13
14
|
TunedValue, get_device_vmem_limit, get_tuned_block_sizes)
|
|
14
15
|
from tpu_inference.kernels.quantized_matmul.util import (get_kernel_name,
|
|
15
16
|
next_multiple,
|
|
16
17
|
unfold_args)
|
|
17
18
|
|
|
19
|
+
quantize_tensor = util.quantize_tensor
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def xla_quantized_matmul(
|
|
23
|
+
x: jax.Array,
|
|
24
|
+
w_q: jax.Array,
|
|
25
|
+
w_scale: jax.Array,
|
|
26
|
+
quantize_activation=True,
|
|
27
|
+
) -> jax.Array:
|
|
28
|
+
"""
|
|
29
|
+
Reference (pure JAX) implementation of the quantized matmul kernel below.
|
|
30
|
+
|
|
31
|
+
Args:
|
|
32
|
+
x: Activation.
|
|
33
|
+
w_q: Weight quantized array. [n_output_features, n_input_features]
|
|
34
|
+
w_scale: Weight quantization scale. [n_output_features]
|
|
35
|
+
mesh: Mesh to shard on.
|
|
36
|
+
weight_sharding: PartitionSpec for the weight tensor.
|
|
37
|
+
|
|
38
|
+
Returns:
|
|
39
|
+
Output of the quantized matmul.
|
|
40
|
+
"""
|
|
41
|
+
if quantize_activation:
|
|
42
|
+
acc_dtype = jnp.float32
|
|
43
|
+
if quantize_activation and jnp.issubdtype(w_q.dtype, jnp.integer):
|
|
44
|
+
acc_dtype = jnp.int32
|
|
45
|
+
|
|
46
|
+
x_q, x_scale = quantize_tensor(x, w_q.dtype)
|
|
47
|
+
out = jax.lax.dot_general(
|
|
48
|
+
x_q,
|
|
49
|
+
w_q,
|
|
50
|
+
dimension_numbers=(((1, ), (1, )), ((), ())),
|
|
51
|
+
preferred_element_type=acc_dtype,
|
|
52
|
+
).astype(jnp.float32)
|
|
53
|
+
out *= x_scale
|
|
54
|
+
else:
|
|
55
|
+
out = jax.lax.dot_general(
|
|
56
|
+
x,
|
|
57
|
+
w_q,
|
|
58
|
+
dimension_numbers=(((1, ), (1, )), ((), ())),
|
|
59
|
+
preferred_element_type=jnp.float32,
|
|
60
|
+
)
|
|
61
|
+
out *= jnp.expand_dims(w_scale, 0)
|
|
62
|
+
return out.astype(x.dtype)
|
|
63
|
+
|
|
18
64
|
|
|
19
65
|
def quantize_array(
|
|
20
66
|
x: jax.Array, # [bs_block_size, in_block_size]
|
|
@@ -50,11 +96,20 @@ def get_vmem_limit(
|
|
|
50
96
|
"""Calculate VMEM limit for the kernel."""
|
|
51
97
|
|
|
52
98
|
# Calculate in/out VMEM size.
|
|
53
|
-
x_size = batch_block_size *
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
99
|
+
x_size = (batch_block_size *
|
|
100
|
+
in_block_size * (dtypes.bit_width(x_dtype) if hasattr(
|
|
101
|
+
dtypes, "bit_width") else dtypes.itemsize_bits(x_dtype)))
|
|
102
|
+
x_abs_max_size = (
|
|
103
|
+
batch_block_size * (dtypes.bit_width(scale_dtype) if hasattr(
|
|
104
|
+
dtypes, "bit_width") else dtypes.itemsize_bits(scale_dtype)))
|
|
105
|
+
w_q_size = (out_block_size *
|
|
106
|
+
in_block_size * (dtypes.bit_width(w_q_dtype) if hasattr(
|
|
107
|
+
dtypes, "bit_width") else dtypes.itemsize_bits(w_q_dtype)))
|
|
108
|
+
w_scale_size = (out_block_size * (dtypes.bit_width(scale_dtype) if hasattr(
|
|
109
|
+
dtypes, "bit_width") else dtypes.itemsize_bits(scale_dtype)))
|
|
110
|
+
out_size = (batch_block_size *
|
|
111
|
+
out_block_size * (dtypes.bit_width(out_dtype) if hasattr(
|
|
112
|
+
dtypes, "bit_width") else dtypes.itemsize_bits(out_dtype)))
|
|
58
113
|
|
|
59
114
|
vmem_in_out = x_size + x_abs_max_size + w_q_size + w_scale_size + out_size
|
|
60
115
|
vmem_in_out *= 2 # Account for compute and vreg spills.
|
|
@@ -68,9 +123,15 @@ def get_vmem_limit(
|
|
|
68
123
|
vmem_in_out += out_size if (n_batch > 1 or n_out > 1) else 0
|
|
69
124
|
|
|
70
125
|
# Calculate scratch VMEM size.
|
|
71
|
-
acc_size = batch_block_size *
|
|
72
|
-
|
|
73
|
-
|
|
126
|
+
acc_size = (batch_block_size *
|
|
127
|
+
out_block_size * (dtypes.bit_width(acc_dtype) if hasattr(
|
|
128
|
+
dtypes, "bit_width") else dtypes.itemsize_bits(acc_dtype)))
|
|
129
|
+
x_q_size = (batch_block_size *
|
|
130
|
+
in_block_size * (dtypes.bit_width(x_q_dtype) if hasattr(
|
|
131
|
+
dtypes, "bit_width") else dtypes.itemsize_bits(x_q_dtype)))
|
|
132
|
+
x_scale_size = (
|
|
133
|
+
batch_block_size * (dtypes.bit_width(scale_dtype) if hasattr(
|
|
134
|
+
dtypes, "bit_width") else dtypes.itemsize_bits(scale_dtype)))
|
|
74
135
|
|
|
75
136
|
vmem_scratch = acc_size if save_acc else 0
|
|
76
137
|
vmem_scratch += x_q_size + x_scale_size if save_x_q else 0
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|