tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +78 -1
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +38 -7
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +17 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +28 -5
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +74 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +88 -25
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -64
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +72 -37
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +45 -15
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +41 -16
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +42 -36
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +63 -50
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
- tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import jax
|
|
16
|
+
import jax.numpy as jnp
|
|
17
|
+
from absl.testing import absltest, parameterized
|
|
18
|
+
from jax._src import test_util as jtu
|
|
19
|
+
|
|
20
|
+
from tpu_inference.kernels.megablox.gmm import gmm
|
|
21
|
+
|
|
22
|
+
jax.config.parse_flags_with_absl()
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def quantize_tensor(x: jax.Array,
|
|
26
|
+
dtype: jnp.dtype,
|
|
27
|
+
axis: int = -1,
|
|
28
|
+
block_size: int = 256):
|
|
29
|
+
if jnp.issubdtype(dtype, jnp.integer):
|
|
30
|
+
dtype_info = jnp.iinfo(dtype)
|
|
31
|
+
max_val = int(dtype_info.max)
|
|
32
|
+
min_val = int(dtype_info.min)
|
|
33
|
+
else:
|
|
34
|
+
dtype_info = jnp.finfo(dtype)
|
|
35
|
+
max_val = float(dtype_info.max)
|
|
36
|
+
min_val = float(dtype_info.min)
|
|
37
|
+
|
|
38
|
+
orig_shape = x.shape
|
|
39
|
+
blocked_shape = orig_shape[:axis] + (-1,
|
|
40
|
+
block_size) + orig_shape[axis + 1:]
|
|
41
|
+
x_blocked = x.reshape(blocked_shape)
|
|
42
|
+
|
|
43
|
+
x_blocked_abs_max = jnp.max(jnp.abs(x_blocked),
|
|
44
|
+
axis=axis + 1,
|
|
45
|
+
keepdims=True)
|
|
46
|
+
scale = x_blocked_abs_max / max_val
|
|
47
|
+
x_blocked_q = jnp.clip(x_blocked / scale, min_val, max_val).astype(dtype)
|
|
48
|
+
|
|
49
|
+
x_q = x_blocked_q.reshape(orig_shape)
|
|
50
|
+
scale = scale.squeeze(axis=axis + 1).astype(jnp.float32)
|
|
51
|
+
return x_q, scale
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def reference_gmm(
|
|
55
|
+
lhs: jax.Array,
|
|
56
|
+
rhs: jax.Array,
|
|
57
|
+
group_sizes: jax.Array,
|
|
58
|
+
rhs_scale: jax.Array | None = None,
|
|
59
|
+
rhs_bias: jax.Array | None = None,
|
|
60
|
+
group_offset: jax.Array | None = None,
|
|
61
|
+
):
|
|
62
|
+
num_groups, out_size, in_size = rhs.shape
|
|
63
|
+
assert lhs.shape[1] == in_size
|
|
64
|
+
|
|
65
|
+
if group_offset is None:
|
|
66
|
+
group_offset = jnp.array(0, dtype=jnp.int32)
|
|
67
|
+
start = group_sizes[:group_offset].sum()
|
|
68
|
+
group_sizes = group_sizes[group_offset:]
|
|
69
|
+
assert len(group_sizes) == num_groups
|
|
70
|
+
|
|
71
|
+
if rhs_scale is not None:
|
|
72
|
+
num_blocks = rhs_scale.shape[1]
|
|
73
|
+
else:
|
|
74
|
+
num_blocks = 1
|
|
75
|
+
block_size = in_size // num_blocks
|
|
76
|
+
|
|
77
|
+
gmm_out = [jnp.zeros((start, out_size), lhs.dtype)]
|
|
78
|
+
for group in range(num_groups):
|
|
79
|
+
end = start + group_sizes[group]
|
|
80
|
+
|
|
81
|
+
lhs_slice = lhs[start:end]
|
|
82
|
+
rhs_slice = rhs[group]
|
|
83
|
+
|
|
84
|
+
out = 0
|
|
85
|
+
for block in range(num_blocks):
|
|
86
|
+
block_start = block * block_size
|
|
87
|
+
block_end = block_start + block_size
|
|
88
|
+
lhs_block = lhs_slice[:, block_start:block_end].astype(jnp.float32)
|
|
89
|
+
rhs_block = rhs_slice[:, block_start:block_end].astype(jnp.float32)
|
|
90
|
+
|
|
91
|
+
acc = jnp.einsum("bd,hd->bh", lhs_block, rhs_block)
|
|
92
|
+
if rhs_scale is not None:
|
|
93
|
+
acc *= rhs_scale[group][block]
|
|
94
|
+
out += acc
|
|
95
|
+
if rhs_bias is not None:
|
|
96
|
+
out = out + rhs_bias[group]
|
|
97
|
+
|
|
98
|
+
gmm_out.append(out.astype(lhs.dtype))
|
|
99
|
+
start = end
|
|
100
|
+
|
|
101
|
+
return jnp.concat(gmm_out, axis=0)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
@jtu.with_config(jax_numpy_dtype_promotion="standard")
|
|
105
|
+
class GmmTest(jtu.JaxTestCase):
|
|
106
|
+
|
|
107
|
+
@parameterized.product(
|
|
108
|
+
batch_size=[128],
|
|
109
|
+
in_size=[1024],
|
|
110
|
+
out_size=[1024],
|
|
111
|
+
num_groups=[16, 32],
|
|
112
|
+
has_bias=[True, False],
|
|
113
|
+
)
|
|
114
|
+
def test_gmm(self, batch_size, in_size, out_size, num_groups, has_bias):
|
|
115
|
+
key = jax.random.key(0)
|
|
116
|
+
|
|
117
|
+
lhs = jax.random.normal(key, (batch_size, in_size), dtype=jnp.bfloat16)
|
|
118
|
+
rhs = jax.random.normal(key, (num_groups, out_size, in_size),
|
|
119
|
+
dtype=jnp.bfloat16)
|
|
120
|
+
rhs_bias = None
|
|
121
|
+
if has_bias:
|
|
122
|
+
rhs_bias = jax.random.normal(key, (num_groups, 1, out_size),
|
|
123
|
+
dtype=jnp.bfloat16)
|
|
124
|
+
|
|
125
|
+
group_sizes = jax.random.randint(key, (num_groups, ),
|
|
126
|
+
0,
|
|
127
|
+
batch_size,
|
|
128
|
+
dtype=jnp.int32)
|
|
129
|
+
|
|
130
|
+
expected = reference_gmm(lhs, rhs, group_sizes, rhs_bias=rhs_bias)
|
|
131
|
+
actual = gmm(
|
|
132
|
+
lhs,
|
|
133
|
+
rhs,
|
|
134
|
+
group_sizes,
|
|
135
|
+
rhs_bias=rhs_bias,
|
|
136
|
+
transpose_rhs=True,
|
|
137
|
+
preferred_element_type=jnp.bfloat16,
|
|
138
|
+
)
|
|
139
|
+
|
|
140
|
+
self.assertArraysAllClose(actual, expected)
|
|
141
|
+
|
|
142
|
+
@parameterized.product(
|
|
143
|
+
batch_size=[128],
|
|
144
|
+
in_size=[1024],
|
|
145
|
+
out_size=[1024],
|
|
146
|
+
num_groups=[16, 32],
|
|
147
|
+
has_bias=[True, False],
|
|
148
|
+
weight_dtype=[jnp.int8, jnp.float8_e5m2, jnp.float4_e2m1fn],
|
|
149
|
+
block_size=[256, 512],
|
|
150
|
+
)
|
|
151
|
+
def test_gmm_weight_quantized(
|
|
152
|
+
self,
|
|
153
|
+
batch_size,
|
|
154
|
+
in_size,
|
|
155
|
+
out_size,
|
|
156
|
+
num_groups,
|
|
157
|
+
has_bias,
|
|
158
|
+
weight_dtype,
|
|
159
|
+
block_size,
|
|
160
|
+
):
|
|
161
|
+
if weight_dtype == jnp.float4_e2m1fn and not jtu.is_device_tpu_at_least(
|
|
162
|
+
version=7):
|
|
163
|
+
self.skipTest("Expect TPUv7+")
|
|
164
|
+
key = jax.random.key(0)
|
|
165
|
+
|
|
166
|
+
lhs = jax.random.normal(key, (batch_size, in_size), dtype=jnp.bfloat16)
|
|
167
|
+
rhs = jax.random.normal(key, (num_groups, out_size, in_size),
|
|
168
|
+
dtype=jnp.bfloat16)
|
|
169
|
+
rhs_q, rhs_scale = quantize_tensor(rhs,
|
|
170
|
+
weight_dtype,
|
|
171
|
+
axis=2,
|
|
172
|
+
block_size=block_size)
|
|
173
|
+
rhs_scale = jnp.swapaxes(rhs_scale, 1, 2)
|
|
174
|
+
rhs_scale = jnp.expand_dims(rhs_scale, axis=2)
|
|
175
|
+
|
|
176
|
+
rhs_bias = None
|
|
177
|
+
if has_bias:
|
|
178
|
+
rhs_bias = jax.random.normal(key, (num_groups, 1, out_size),
|
|
179
|
+
dtype=jnp.bfloat16)
|
|
180
|
+
|
|
181
|
+
group_sizes = jax.random.randint(key, (num_groups, ),
|
|
182
|
+
0,
|
|
183
|
+
batch_size,
|
|
184
|
+
dtype=jnp.int32)
|
|
185
|
+
|
|
186
|
+
expected = reference_gmm(lhs,
|
|
187
|
+
rhs_q,
|
|
188
|
+
group_sizes,
|
|
189
|
+
rhs_scale=rhs_scale,
|
|
190
|
+
rhs_bias=rhs_bias)
|
|
191
|
+
actual = gmm(
|
|
192
|
+
lhs,
|
|
193
|
+
rhs_q,
|
|
194
|
+
group_sizes,
|
|
195
|
+
rhs_scale=rhs_scale,
|
|
196
|
+
rhs_bias=rhs_bias,
|
|
197
|
+
transpose_rhs=True,
|
|
198
|
+
preferred_element_type=jnp.bfloat16,
|
|
199
|
+
)
|
|
200
|
+
|
|
201
|
+
self.assertArraysAllClose(actual, expected, atol=3e-1, rtol=3e-1)
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
if __name__ == "__main__":
|
|
205
|
+
absltest.main(testLoader=jtu.JaxTestLoader())
|
tests/kernels/mla_v1_test.py
CHANGED
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
import jax
|
|
2
16
|
import jax.numpy as jnp
|
|
3
17
|
import numpy as np
|
|
@@ -42,6 +56,7 @@ class MlaRaggedPagedAttentionKernelTest(jtu.JaxTestCase):
|
|
|
42
56
|
|
|
43
57
|
padded_r_dim = align_to(r_dim, 128)
|
|
44
58
|
padded_lkv_dim = align_to(lkv_dim, 128)
|
|
59
|
+
padded_kv_dim = padded_lkv_dim + padded_r_dim
|
|
45
60
|
packing = get_dtype_packing(kv_dtype)
|
|
46
61
|
q_lens = [s[0] for s in seq_lens]
|
|
47
62
|
kv_lens_list = [s[1] for s in seq_lens]
|
|
@@ -69,13 +84,10 @@ class MlaRaggedPagedAttentionKernelTest(jtu.JaxTestCase):
|
|
|
69
84
|
new_kv_c = gen_random((total_q_len, lkv_dim), kv_dtype)
|
|
70
85
|
new_k_pe = gen_random((total_q_len, r_dim), kv_dtype)
|
|
71
86
|
|
|
72
|
-
|
|
73
|
-
(total_num_pages, page_size // packing, packing,
|
|
87
|
+
cache_kv = gen_random(
|
|
88
|
+
(total_num_pages, page_size // packing, packing, padded_kv_dim),
|
|
74
89
|
kv_dtype,
|
|
75
90
|
)
|
|
76
|
-
cache_k_pe = gen_random(
|
|
77
|
-
(total_num_pages, page_size // packing, packing, padded_r_dim),
|
|
78
|
-
kv_dtype)
|
|
79
91
|
kv_lens = jnp.array(kv_lens_list, dtype=jnp.int32)
|
|
80
92
|
page_indices = jnp.array(page_indices_list, dtype=jnp.int32)
|
|
81
93
|
cu_q_lens = jnp.array(cu_q_lens_list, dtype=jnp.int32)
|
|
@@ -84,14 +96,13 @@ class MlaRaggedPagedAttentionKernelTest(jtu.JaxTestCase):
|
|
|
84
96
|
ql_nope_for_kernel = ql_nope.copy()
|
|
85
97
|
q_pe_for_kernel = q_pe.copy()
|
|
86
98
|
|
|
87
|
-
expected_out,
|
|
99
|
+
expected_out, expected_updated_kv = (
|
|
88
100
|
mla.ref_mla_ragged_paged_attention(
|
|
89
101
|
ql_nope,
|
|
90
102
|
q_pe,
|
|
91
103
|
new_kv_c,
|
|
92
104
|
new_k_pe,
|
|
93
|
-
|
|
94
|
-
cache_k_pe.copy(),
|
|
105
|
+
cache_kv.copy(),
|
|
95
106
|
kv_lens,
|
|
96
107
|
page_indices,
|
|
97
108
|
cu_q_lens,
|
|
@@ -101,50 +112,141 @@ class MlaRaggedPagedAttentionKernelTest(jtu.JaxTestCase):
|
|
|
101
112
|
soft_cap=soft_cap,
|
|
102
113
|
))
|
|
103
114
|
|
|
104
|
-
kernel_out,
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
vmem_limit_bytes=vmem_limit_bytes,
|
|
122
|
-
))
|
|
115
|
+
kernel_out, kernel_updated_kv = (mla.mla_ragged_paged_attention(
|
|
116
|
+
ql_nope_for_kernel,
|
|
117
|
+
q_pe_for_kernel,
|
|
118
|
+
new_kv_c,
|
|
119
|
+
new_k_pe,
|
|
120
|
+
cache_kv.copy(),
|
|
121
|
+
kv_lens,
|
|
122
|
+
page_indices,
|
|
123
|
+
cu_q_lens,
|
|
124
|
+
distribution,
|
|
125
|
+
sm_scale=sm_scale,
|
|
126
|
+
sliding_window=sliding_window,
|
|
127
|
+
soft_cap=soft_cap,
|
|
128
|
+
num_kv_pages_per_block=num_kv_pages_per_block,
|
|
129
|
+
num_queries_per_block=num_queries_per_block,
|
|
130
|
+
vmem_limit_bytes=vmem_limit_bytes,
|
|
131
|
+
))
|
|
123
132
|
|
|
124
133
|
self.assertEqual(expected_out.shape,
|
|
125
134
|
(total_q_len, num_heads, padded_lkv_dim))
|
|
126
135
|
self.assertEqual(
|
|
127
|
-
|
|
128
|
-
(total_num_pages, page_size // packing, packing,
|
|
129
|
-
)
|
|
130
|
-
self.assertEqual(
|
|
131
|
-
expeceted_updated_k_pe.shape,
|
|
132
|
-
(total_num_pages, page_size // packing, packing, padded_r_dim),
|
|
136
|
+
expected_updated_kv.shape,
|
|
137
|
+
(total_num_pages, page_size // packing, packing, padded_kv_dim),
|
|
133
138
|
)
|
|
134
139
|
self.assertEqual(expected_out.dtype, kv_dtype)
|
|
135
|
-
self.assertEqual(
|
|
136
|
-
self.assertEqual(expeceted_updated_k_pe.dtype, kv_dtype)
|
|
140
|
+
self.assertEqual(expected_updated_kv.dtype, kv_dtype)
|
|
137
141
|
|
|
138
142
|
self.assertAllClose(expected_out, kernel_out, atol=0.2, rtol=0.2)
|
|
139
|
-
self.assertAllClose(
|
|
140
|
-
|
|
141
|
-
atol=0.2,
|
|
142
|
-
rtol=0.2)
|
|
143
|
-
self.assertAllClose(expeceted_updated_k_pe,
|
|
144
|
-
kernel_updated_k_pe,
|
|
143
|
+
self.assertAllClose(expected_updated_kv,
|
|
144
|
+
kernel_updated_kv,
|
|
145
145
|
atol=0.2,
|
|
146
146
|
rtol=0.2)
|
|
147
147
|
|
|
148
|
+
def test_update_kv_cache(self):
|
|
149
|
+
lkv_dim = 4
|
|
150
|
+
r_dim = 4
|
|
151
|
+
padded_lkv_dim = align_to(lkv_dim, 128)
|
|
152
|
+
padded_r_dim = align_to(r_dim, 128)
|
|
153
|
+
kv_dtype = jnp.bfloat16
|
|
154
|
+
new_kv_c = jnp.arange(16, dtype=kv_dtype).reshape((4, lkv_dim))
|
|
155
|
+
new_k_pe = (jnp.arange(16, dtype=kv_dtype).reshape((4, r_dim)) + 100)
|
|
156
|
+
total_num_pages = 2
|
|
157
|
+
page_size = 4
|
|
158
|
+
cache_kv_shape = mla.get_kv_cache_shape(
|
|
159
|
+
total_num_pages,
|
|
160
|
+
page_size,
|
|
161
|
+
padded_lkv_dim + padded_r_dim,
|
|
162
|
+
kv_dtype,
|
|
163
|
+
)
|
|
164
|
+
cache_kv = jnp.zeros(cache_kv_shape, dtype=kv_dtype)
|
|
165
|
+
|
|
166
|
+
# two sequences, first with 3 tokens, second with 1 token
|
|
167
|
+
kv_lens = jnp.array([3, 1], dtype=jnp.int32)
|
|
168
|
+
# first seq uses page 0, second uses page 1
|
|
169
|
+
page_indices = jnp.array([0, -1, 1, -1], dtype=jnp.int32)
|
|
170
|
+
# three tokens for first seq, one for second
|
|
171
|
+
cu_q_lens = jnp.array([0, 3, 4], dtype=jnp.int32)
|
|
172
|
+
distribution = jnp.array([0, 0, 2], dtype=jnp.int32)
|
|
173
|
+
|
|
174
|
+
# manually compute the expected cache
|
|
175
|
+
padded_new_kv_c = jnp.pad(new_kv_c,
|
|
176
|
+
((0, 0), (0, padded_lkv_dim - lkv_dim)),
|
|
177
|
+
constant_values=0)
|
|
178
|
+
padded_new_k_pe = jnp.pad(new_k_pe,
|
|
179
|
+
((0, 0), (0, padded_r_dim - r_dim)),
|
|
180
|
+
constant_values=0)
|
|
181
|
+
|
|
182
|
+
expected_cache = cache_kv
|
|
183
|
+
# First sequence
|
|
184
|
+
# token 0
|
|
185
|
+
page_idx, row, col = 0, 0, 0
|
|
186
|
+
expected_cache = expected_cache.at[page_idx, row,
|
|
187
|
+
col, :padded_lkv_dim].set(
|
|
188
|
+
padded_new_kv_c[0])
|
|
189
|
+
expected_cache = expected_cache.at[page_idx, row, col,
|
|
190
|
+
padded_lkv_dim:padded_lkv_dim +
|
|
191
|
+
padded_r_dim].set(
|
|
192
|
+
padded_new_k_pe[0])
|
|
193
|
+
# token 1
|
|
194
|
+
page_idx, row, col = 0, 0, 1
|
|
195
|
+
expected_cache = expected_cache.at[page_idx, row,
|
|
196
|
+
col, :padded_lkv_dim].set(
|
|
197
|
+
padded_new_kv_c[1])
|
|
198
|
+
expected_cache = expected_cache.at[page_idx, row, col,
|
|
199
|
+
padded_lkv_dim:padded_lkv_dim +
|
|
200
|
+
padded_r_dim].set(
|
|
201
|
+
padded_new_k_pe[1])
|
|
202
|
+
# token 2
|
|
203
|
+
page_idx, row, col = 0, 1, 0
|
|
204
|
+
expected_cache = expected_cache.at[page_idx, row,
|
|
205
|
+
col, :padded_lkv_dim].set(
|
|
206
|
+
padded_new_kv_c[2])
|
|
207
|
+
expected_cache = expected_cache.at[page_idx, row, col,
|
|
208
|
+
padded_lkv_dim:padded_lkv_dim +
|
|
209
|
+
padded_r_dim].set(
|
|
210
|
+
padded_new_k_pe[2])
|
|
211
|
+
|
|
212
|
+
# Second sequence
|
|
213
|
+
# token 0
|
|
214
|
+
page_idx, row, col = 1, 0, 0
|
|
215
|
+
expected_cache = expected_cache.at[page_idx, row,
|
|
216
|
+
col, :padded_lkv_dim].set(
|
|
217
|
+
padded_new_kv_c[3])
|
|
218
|
+
expected_cache = expected_cache.at[page_idx, row, col,
|
|
219
|
+
padded_lkv_dim:padded_lkv_dim +
|
|
220
|
+
padded_r_dim].set(
|
|
221
|
+
padded_new_k_pe[3])
|
|
222
|
+
|
|
223
|
+
updated_cache = mla.update_kv_cache(
|
|
224
|
+
new_kv_c,
|
|
225
|
+
new_k_pe,
|
|
226
|
+
cache_kv,
|
|
227
|
+
kv_lens,
|
|
228
|
+
page_indices,
|
|
229
|
+
cu_q_lens,
|
|
230
|
+
distribution,
|
|
231
|
+
)
|
|
232
|
+
|
|
233
|
+
self.assertAllClose(updated_cache, expected_cache)
|
|
234
|
+
|
|
235
|
+
def test_get_kv_cache_shape(self):
|
|
236
|
+
total_num_pages = 10
|
|
237
|
+
page_size = 16
|
|
238
|
+
lkv_dim = 128
|
|
239
|
+
kv_dtype = jnp.bfloat16
|
|
240
|
+
# The calculation for the expected shape is as follows:
|
|
241
|
+
# kv_packing is determined by the dtype, which is 2 for bfloat16.
|
|
242
|
+
# The second dimension is page_size / kv_packing = 16 / 2 = 8
|
|
243
|
+
# The third dimension is kv_packing = 2
|
|
244
|
+
# The fourth dimension is lkv_dim aligned to 128, which is 128
|
|
245
|
+
expected_shape = (10, 8, 2, 128)
|
|
246
|
+
self.assertEqual(
|
|
247
|
+
mla.get_kv_cache_shape(total_num_pages, page_size, lkv_dim,
|
|
248
|
+
kv_dtype), expected_shape)
|
|
249
|
+
|
|
148
250
|
def test_ragged_paged_attention_basic(self):
|
|
149
251
|
dtype = jnp.bfloat16
|
|
150
252
|
seq_lens = [(192, 328), (128, 180), (64, 255)]
|
|
@@ -1,7 +1,5 @@
|
|
|
1
1
|
# SPDX-License-Identifier: Apache-2.0
|
|
2
2
|
|
|
3
|
-
import functools
|
|
4
|
-
|
|
5
3
|
import jax
|
|
6
4
|
import jax.numpy as jnp
|
|
7
5
|
from absl.testing import absltest, parameterized
|
|
@@ -10,6 +8,7 @@ from jax._src import test_util as jtu
|
|
|
10
8
|
from tpu_inference.kernels.quantized_matmul import (kernel, tuned_block_sizes,
|
|
11
9
|
util)
|
|
12
10
|
|
|
11
|
+
xla_quantized_matmul = kernel.xla_quantized_matmul
|
|
13
12
|
quantized_matmul_kernel = kernel.quantized_matmul_kernel
|
|
14
13
|
quantize_tensor = util.quantize_tensor
|
|
15
14
|
get_tuned_block_sizes = tuned_block_sizes.get_tuned_block_sizes
|
|
@@ -17,37 +16,6 @@ get_tuned_block_sizes = tuned_block_sizes.get_tuned_block_sizes
|
|
|
17
16
|
jax.config.parse_flags_with_absl()
|
|
18
17
|
|
|
19
18
|
|
|
20
|
-
@functools.partial(jax.jit, static_argnames=["quantize_activation"])
|
|
21
|
-
def reference_quantized_matmul(
|
|
22
|
-
x: jax.Array,
|
|
23
|
-
w_q: jax.Array,
|
|
24
|
-
w_scale: jax.Array,
|
|
25
|
-
quantize_activation=True,
|
|
26
|
-
):
|
|
27
|
-
if quantize_activation:
|
|
28
|
-
acc_dtype = jnp.float32
|
|
29
|
-
if quantize_activation and jnp.issubdtype(w_q.dtype, jnp.integer):
|
|
30
|
-
acc_dtype = jnp.int32
|
|
31
|
-
|
|
32
|
-
x_q, x_scale = quantize_tensor(x, w_q.dtype)
|
|
33
|
-
out = jax.lax.dot_general(
|
|
34
|
-
x_q,
|
|
35
|
-
w_q,
|
|
36
|
-
dimension_numbers=(((1, ), (1, )), ((), ())),
|
|
37
|
-
preferred_element_type=acc_dtype,
|
|
38
|
-
).astype(jnp.float32)
|
|
39
|
-
out *= x_scale
|
|
40
|
-
else:
|
|
41
|
-
out = jax.lax.dot_general(
|
|
42
|
-
x,
|
|
43
|
-
w_q,
|
|
44
|
-
dimension_numbers=(((1, ), (1, )), ((), ())),
|
|
45
|
-
preferred_element_type=jnp.float32,
|
|
46
|
-
)
|
|
47
|
-
out *= jnp.expand_dims(w_scale, 0)
|
|
48
|
-
return out.astype(x.dtype)
|
|
49
|
-
|
|
50
|
-
|
|
51
19
|
@jtu.with_config(jax_numpy_dtype_promotion="standard")
|
|
52
20
|
class QuantizedMatmulKernelTest(jtu.JaxTestCase):
|
|
53
21
|
|
|
@@ -94,7 +62,7 @@ class QuantizedMatmulKernelTest(jtu.JaxTestCase):
|
|
|
94
62
|
x_q_dtype=x_q_dtype,
|
|
95
63
|
tuned_value=tuned_value,
|
|
96
64
|
)
|
|
97
|
-
expected = reference_quantized_matmul(
|
|
65
|
+
expected = xla_quantized_matmul(
|
|
98
66
|
x, w_q, w_scale, quantize_activation=quantize_activation)
|
|
99
67
|
|
|
100
68
|
self.assertAllClose(output,
|
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
import jax
|
|
2
16
|
import jax.numpy as jnp
|
|
3
17
|
import numpy as np
|
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
import random
|
|
2
16
|
|
|
3
17
|
import jax
|
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
import jax
|
|
2
16
|
import jax.numpy as jnp
|
|
3
17
|
import numpy as np
|
|
@@ -176,7 +190,9 @@ class RaggedPagedAttentionHeadDim64KernelTest(jtu.JaxTestCase):
|
|
|
176
190
|
)
|
|
177
191
|
output = output[:cu_q_lens[distribution[-1]]]
|
|
178
192
|
|
|
179
|
-
dtype_bits = dtypes.bit_width(jnp.dtype(kv_dtype))
|
|
193
|
+
dtype_bits = (dtypes.bit_width(jnp.dtype(kv_dtype)) if hasattr(
|
|
194
|
+
dtypes, "bit_width") else dtypes.itemsize_bits(
|
|
195
|
+
jnp.dtype(kv_dtype)))
|
|
180
196
|
tols = {
|
|
181
197
|
32: 0.15,
|
|
182
198
|
16: 0.2,
|
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
import jax
|
|
2
16
|
import jax.numpy as jnp
|
|
3
17
|
import numpy as np
|
|
@@ -162,7 +176,9 @@ class RaggedPagedAttentionKernelTest(jtu.JaxTestCase):
|
|
|
162
176
|
)
|
|
163
177
|
output = output[:cu_q_lens[distribution[-1]]]
|
|
164
178
|
|
|
165
|
-
dtype_bits = dtypes.bit_width(jnp.dtype(kv_dtype))
|
|
179
|
+
dtype_bits = (dtypes.bit_width(jnp.dtype(kv_dtype)) if hasattr(
|
|
180
|
+
dtypes, "bit_width") else dtypes.itemsize_bits(
|
|
181
|
+
jnp.dtype(kv_dtype)))
|
|
166
182
|
tols = {
|
|
167
183
|
32: 0.15,
|
|
168
184
|
16: 0.2,
|
tests/layers/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|