tpu_inference-0.11.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: this version of tpu-inference has been flagged as a potentially problematic release.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_adapters.py +83 -0
- tests/core/test_core_tpu.py +523 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/test_lora.py +123 -0
- tests/test_base.py +201 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +218 -0
- tests/tpu_backend_test.py +59 -0
- tpu_inference/__init__.py +30 -0
- tpu_inference/adapters/__init__.py +0 -0
- tpu_inference/adapters/vllm_adapters.py +42 -0
- tpu_inference/adapters/vllm_config_adapters.py +134 -0
- tpu_inference/backend.py +69 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/adapters.py +153 -0
- tpu_inference/core/core_tpu.py +776 -0
- tpu_inference/core/disagg_executor.py +117 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/di/__init__.py +0 -0
- tpu_inference/di/abstracts.py +28 -0
- tpu_inference/di/host.py +76 -0
- tpu_inference/di/interfaces.py +51 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/tpu_connector.py +699 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +346 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/interfaces/__init__.py +0 -0
- tpu_inference/interfaces/cache.py +31 -0
- tpu_inference/interfaces/config.py +47 -0
- tpu_inference/interfaces/config_parts.py +117 -0
- tpu_inference/interfaces/engine.py +51 -0
- tpu_inference/interfaces/outputs.py +22 -0
- tpu_inference/interfaces/params.py +21 -0
- tpu_inference/interfaces/platform.py +74 -0
- tpu_inference/interfaces/request.py +39 -0
- tpu_inference/interfaces/scheduler.py +31 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +254 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/attention_interface.py +356 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/binary_search.py +295 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +172 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +95 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
- tpu_inference/layers/jax/sharding.py +406 -0
- tpu_inference/layers/jax/transformer_block.py +76 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +184 -0
- tpu_inference/layers/vllm/fused_moe.py +399 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +34 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
- tpu_inference/layers/vllm/sharding.py +151 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +308 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1233 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +433 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/llama3.py +366 -0
- tpu_inference/models/jax/llama4.py +473 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +976 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
- tpu_inference/models/jax/utils/weight_utils.py +510 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_jax.py +257 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table_jax.py +122 -0
- tpu_inference/runner/compilation_manager.py +672 -0
- tpu_inference/runner/input_batch_jax.py +435 -0
- tpu_inference/runner/kv_cache.py +119 -0
- tpu_inference/runner/kv_cache_manager.py +460 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +208 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +250 -0
- tpu_inference/runner/structured_decoding_manager.py +89 -0
- tpu_inference/runner/tpu_jax_runner.py +771 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +334 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +294 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/_temporary_vllm_compat.py +129 -0
- tpu_inference/worker/base.py +100 -0
- tpu_inference/worker/tpu_worker_jax.py +321 -0
- tpu_inference-0.11.1.dist-info/METADATA +101 -0
- tpu_inference-0.11.1.dist-info/RECORD +168 -0
- tpu_inference-0.11.1.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
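
The manifest above can be reproduced locally: a wheel is an ordinary zip archive, so its file list is directly inspectable. A minimal sketch, assuming the standard wheel filename for this release:

import zipfile

# List every file shipped in the wheel, mirroring the manifest above.
with zipfile.ZipFile("tpu_inference-0.11.1-py3-none-any.whl") as whl:
    for info in whl.infolist():
        print(f"{info.filename}  ({info.file_size} bytes)")
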
tests/core/test_disagg_utils.py
ADDED
@@ -0,0 +1,53 @@
import unittest

from tpu_inference.core.disagg_utils import _parse_slices


class DisaggUtilsTest(unittest.TestCase):

    def test_parse_slices_valid_cases(self):
        """Tests valid slice strings."""
        # Test with a single slice
        self.assertEqual(_parse_slices("2x2"), ((2, 2), ))
        self.assertEqual(_parse_slices("2"), (2, ))

        # Test with multiple slices
        self.assertEqual(_parse_slices("2x2,2x1,3,2x4"),
                         ((2, 2), (2, 1), 3, (2, 4)))

        # Test with various dimensions
        self.assertEqual(_parse_slices("1x1,10x10,5x3"),
                         ((1, 1), (10, 10), (5, 3)))

        # Test with an empty string
        self.assertEqual(_parse_slices(""), ())

    def test_parse_slices_with_whitespace(self):
        """Tests valid slice strings with extra whitespace."""
        self.assertEqual(_parse_slices(" 2x2 "), ((2, 2), ))
        self.assertEqual(_parse_slices(" 2x2 , 2x1 , 2x4 "),
                         ((2, 2), (2, 1), (2, 4)))
        # The current implementation allows spaces inside the slice definition
        self.assertEqual(_parse_slices("2 x 2"), ((2, 2), ))
        self.assertEqual(_parse_slices(" 10 x 10 "), ((10, 10), ))

    def test_parse_slices_invalid_cases(self):
        """Tests malformed slice strings that should raise ValueError."""
        invalid_strings = [
            "2*2",  # wrong separator
            "2x",  # incomplete
            "axb",  # not integers
            "2x2x2",  # too many dimensions
            "2x2,3*3",  # partially malformed
            ",2x2",  # leading comma
            "2x2,",  # trailing comma
            "2x2,,2x1",  # empty slice in middle
        ]
        for invalid_str in invalid_strings:
            with self.subTest(invalid_str=invalid_str):
                with self.assertRaises(ValueError):
                    _parse_slices(invalid_str)


if __name__ == '__main__':
    unittest.main()
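
For readers skimming this hunk, the behaviour pinned down by the assertions can be summarised with a small parser. This is a minimal sketch inferred from the test expectations, not the actual tpu_inference.core.disagg_utils._parse_slices implementation; the name parse_slices below is illustrative.

# Hypothetical parser satisfying the assertions above; the packaged
# _parse_slices may differ in details.
def parse_slices(spec: str):
    """Parse "2x2,3" into ((2, 2), 3); an empty string parses to ()."""
    spec = spec.strip()
    if not spec:
        return ()
    result = []
    for part in spec.split(","):
        dims = [d.strip() for d in part.split("x")]
        if not all(d.isdigit() for d in dims):
            raise ValueError(f"Invalid slice spec: {part!r}")
        if len(dims) == 1:
            result.append(int(dims[0]))  # "3" -> 3
        elif len(dims) == 2:
            result.append((int(dims[0]), int(dims[1])))  # "2x2" -> (2, 2)
        else:
            raise ValueError(f"Invalid slice spec: {part!r}")
    return tuple(result)
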
tests/core/test_init.py
ADDED
@@ -0,0 +1,49 @@
import importlib
import unittest
from unittest.mock import patch


class TestPathwaysInit(unittest.TestCase):

    @patch.dict("os.environ", {"JAX_PLATFORMS": "proxy,cpu"})
    def test_VLLM_TPU_USING_PATHWAYS_enabled(self):
        """Test when JAX_PLATFORMS contains 'proxy'."""

        # Import vllm.envs to test the VLLM_TPU_USING_PATHWAYS logic
        import vllm.envs as envs

        # Reload the module to ensure fresh import
        importlib.reload(envs)

        # Check that VLLM_TPU_USING_PATHWAYS is True when JAX_PLATFORMS contains "proxy"
        self.assertTrue(envs.VLLM_TPU_USING_PATHWAYS)

    @patch.dict("os.environ", {"JAX_PLATFORMS": "cpu"})
    def test_VLLM_TPU_USING_PATHWAYS_not_enabled(self):
        """Test when JAX_PLATFORMS does not contain 'proxy'."""

        # Import vllm.envs to test the VLLM_TPU_USING_PATHWAYS logic
        import vllm.envs as envs

        # Reload the module to ensure fresh import
        importlib.reload(envs)

        # Check that VLLM_TPU_USING_PATHWAYS is False when JAX_PLATFORMS doesn't contain "proxy"
        self.assertFalse(envs.VLLM_TPU_USING_PATHWAYS)

    @patch.dict("os.environ", {"JAX_PLATFORMS": "PROXY,CPU"})
    def test_VLLM_TPU_USING_PATHWAYS_case_insensitive(self):
        """Test that JAX_PLATFORMS check is case insensitive."""

        # Import vllm.envs to test the VLLM_TPU_USING_PATHWAYS logic
        import vllm.envs as envs

        # Reload the module to ensure fresh import
        importlib.reload(envs)

        # Check that VLLM_TPU_USING_PATHWAYS is True even with uppercase "PROXY"
        self.assertTrue(envs.VLLM_TPU_USING_PATHWAYS)


if __name__ == "__main__":
    unittest.main()
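
The three cases above together specify the behaviour under test: VLLM_TPU_USING_PATHWAYS is derived from whether the JAX_PLATFORMS environment variable mentions the Pathways "proxy" platform, matched case-insensitively. A minimal sketch of an equivalent check, assuming a simple substring rule (the actual vllm.envs source may compute it differently):

import os

def tpu_using_pathways() -> bool:
    """True when JAX_PLATFORMS lists the Pathways 'proxy' platform, case-insensitively."""
    return "proxy" in os.environ.get("JAX_PLATFORMS", "").lower()
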
tests/kernels/__init__.py
File without changes
tests/kernels/quantized_matmul_kernel_test.py
ADDED
@@ -0,0 +1,191 @@
# SPDX-License-Identifier: Apache-2.0

import functools

import jax
import jax.numpy as jnp
from absl.testing import absltest, parameterized
from jax._src import test_util as jtu

from tpu_inference.kernels.quantized_matmul import (kernel, tuned_block_sizes,
                                                    util)

quantized_matmul_kernel = kernel.quantized_matmul_kernel
quantize_tensor = util.quantize_tensor
get_tuned_block_sizes = tuned_block_sizes.get_tuned_block_sizes

jax.config.parse_flags_with_absl()


@functools.partial(jax.jit, static_argnames=["quantize_activation"])
def reference_quantized_matmul(
    x: jax.Array,
    w_q: jax.Array,
    w_scale: jax.Array,
    quantize_activation=True,
):
    if quantize_activation:
        acc_dtype = jnp.float32
        if quantize_activation and jnp.issubdtype(w_q.dtype, jnp.integer):
            acc_dtype = jnp.int32

        x_q, x_scale = quantize_tensor(x, w_q.dtype)
        out = jax.lax.dot_general(
            x_q,
            w_q,
            dimension_numbers=(((1, ), (1, )), ((), ())),
            preferred_element_type=acc_dtype,
        ).astype(jnp.float32)
        out *= x_scale
    else:
        out = jax.lax.dot_general(
            x,
            w_q,
            dimension_numbers=(((1, ), (1, )), ((), ())),
            preferred_element_type=jnp.float32,
        )
    out *= jnp.expand_dims(w_scale, 0)
    return out.astype(x.dtype)


@jtu.with_config(jax_numpy_dtype_promotion="standard")
class QuantizedMatmulKernelTest(jtu.JaxTestCase):

    def setUp(self):
        super().setUp()
        if not jtu.is_device_tpu_at_least(6):
            self.skipTest("Expect TPUv6+")

    def _test_quantized_matmul(
        self,
        dtype: jnp.dtype,
        q_dtype: jnp.dtype,
        bs: int,
        n_input_features: int,
        n_output_features: int,
        quantize_activation: bool,
        tuned_value=None,
        atol=0.5,
        rtol=0.5,
    ):

        prng_key = jax.random.key(1234)
        k0, k1 = jax.random.split(prng_key, 2)
        x = jax.random.uniform(k0, (bs, n_input_features),
                               dtype=dtype,
                               minval=0,
                               maxval=1)
        w = jax.random.uniform(
            k1,
            (n_output_features, n_input_features),
            dtype=dtype,
            minval=-1,
            maxval=1,
        )
        w_q, w_scale = quantize_tensor(w, q_dtype)
        w_scale = jnp.squeeze(w_scale)
        assert w_scale.shape == (n_output_features, )

        x_q_dtype = w_q.dtype if quantize_activation else dtype
        output = quantized_matmul_kernel(
            x,
            w_q,
            w_scale,
            x_q_dtype=x_q_dtype,
            tuned_value=tuned_value,
        )
        expected = reference_quantized_matmul(
            x, w_q, w_scale, quantize_activation=quantize_activation)

        self.assertAllClose(output,
                            expected,
                            rtol=rtol,
                            atol=atol,
                            check_dtypes=True)

    @parameterized.product(
        dtype=[jnp.bfloat16, jnp.float32],
        q_dtype=[jnp.int8, jnp.float8_e4m3fn],
        bs=[128, 256, 512],
        n_input_features=[128, 256, 512],
        n_output_features=[128, 256, 512],
        quantize_activation=[True],
    )
    def test_quantized_matmul_various_input_shapes(
        self,
        dtype: jnp.dtype,
        q_dtype: jnp.dtype,
        bs: int,
        n_input_features: int,
        n_output_features: int,
        quantize_activation: bool,
    ):
        self._test_quantized_matmul(
            dtype,
            q_dtype,
            bs,
            n_input_features,
            n_output_features,
            quantize_activation=quantize_activation,
            tuned_value=None,
        )

    @parameterized.product(
        dtype=[jnp.bfloat16, jnp.float32],
        q_dtype=[jnp.int8, jnp.float8_e4m3fn],
        bs=[64, 192],
        n_input_features=[64, 192],
        n_output_features=[64, 192],
        quantize_activation=[True],
    )
    def test_quantized_matmul_unaligned_input_shapes(
        self,
        dtype: jnp.dtype,
        q_dtype: jnp.dtype,
        bs: int,
        n_input_features: int,
        n_output_features: int,
        quantize_activation: bool,
    ):
        self._test_quantized_matmul(
            dtype,
            q_dtype,
            bs,
            n_input_features,
            n_output_features,
            quantize_activation=quantize_activation,
            tuned_value=None,
        )

    @parameterized.parameters(
        (jnp.bfloat16, jnp.int8, 128, 1280, 8192, True),
        (jnp.bfloat16, jnp.int8, 128, 28672, 4096, True),
        (jnp.bfloat16, jnp.int8, 128, 4096, 14336, True),
        (jnp.bfloat16, jnp.int8, 128, 4096, 4096, True),
        (jnp.bfloat16, jnp.int8, 128, 6144, 4096, True),
        (jnp.bfloat16, jnp.int8, 128, 7168, 8192, True),
        (jnp.bfloat16, jnp.int8, 128, 8192, 1024, True),
        (jnp.bfloat16, jnp.int8, 128, 8192, 3584, True),
    )
    def test_quantized_matmul_use_tuned_block_sizes(
        self,
        dtype: jnp.dtype,
        q_dtype: jnp.dtype,
        bs: int,
        n_input_features: int,
        n_output_features: int,
        quantize_activation: bool,
    ):
        self._test_quantized_matmul(
            dtype,
            q_dtype,
            bs,
            n_input_features,
            n_output_features,
            quantize_activation=quantize_activation,
            tuned_value=None,
        )


if __name__ == "__main__":
    absltest.main(testLoader=jtu.JaxTestLoader())
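
Context for the reference path above: quantize_tensor is used as a symmetric per-channel quantizer whose squeezed scale has shape (n_output_features,). A minimal JAX sketch of that behaviour, assuming plain absmax scaling; the packaged tpu_inference.kernels.quantized_matmul.util.quantize_tensor may differ in rounding and dtype handling.

import jax.numpy as jnp

def absmax_quantize_per_channel(w: jnp.ndarray, q_dtype: jnp.dtype):
    """Hypothetical per-output-channel absmax quantizer.

    Rows of `w` are output channels; returns (w_q, scale) with scale shaped
    (n_output_features, 1) so that w ~= w_q.astype(float) * scale.
    """
    if jnp.issubdtype(q_dtype, jnp.integer):
        q_max = jnp.iinfo(q_dtype).max
    else:
        q_max = jnp.finfo(q_dtype).max
    scale = jnp.max(jnp.abs(w), axis=-1, keepdims=True) / q_max
    scale = jnp.where(scale == 0, 1.0, scale)  # guard all-zero rows
    w_scaled = w / scale
    if jnp.issubdtype(q_dtype, jnp.integer):
        w_scaled = jnp.round(w_scaled)
    w_q = jnp.clip(w_scaled, -q_max, q_max).astype(q_dtype)
    return w_q, scale.astype(jnp.float32)
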
tests/kernels/ragged_kv_cache_update_v2_test.py
ADDED
@@ -0,0 +1,234 @@
import jax
import jax.numpy as jnp
import numpy as np
from absl.testing import parameterized
from jax._src import test_util as jtu
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

from tpu_inference.kernels.ragged_paged_attention.v2.ragged_kv_cache_update import \
    kv_cache_update


def kv_cache_update_ref(new_kv, slot_mapping, kv_cache):
    """Reference implementation of KV cache update."""
    for i in range(slot_mapping.shape[1]):
        start_idx, new_kv_idx, slice_len = slot_mapping[:, i]
        kv_cache = kv_cache.at[start_idx:start_idx + slice_len].set(
            new_kv[new_kv_idx:new_kv_idx + slice_len])
    return kv_cache


@jtu.with_config(jax_numpy_dtype_promotion="standard")
class KVCacheUpdateTest(jtu.JaxTestCase):

    def _generate_data(self, page_size, combined_kv_head_num, head_dim):
        page_num = 20
        padded_num_tokens = 128
        prng_key = jax.random.key(1234)
        kv_cache = jnp.zeros(
            (page_num * page_size, combined_kv_head_num, head_dim),
            dtype=jnp.bfloat16)
        new_kv = jax.random.normal(
            prng_key, (padded_num_tokens, combined_kv_head_num, head_dim),
            dtype=jnp.bfloat16)
        slice_lens = np.array([7, page_size, page_size, 1, 1, 1, 9],
                              dtype=np.int32)
        num_slices = jnp.array([len(slice_lens)], dtype=np.int32)
        kv_cache_start_indices = np.array([
            page_size * 2 - 7, page_size * 2, page_size * 3, page_size * 4 + 6,
            page_size * 5 + 7, page_size * 6 + 8, page_size * 15 + 3
        ],
                                          dtype=np.int32)
        new_kv_cache_indices = np.concatenate(
            [np.array([0], dtype=np.int32),
             np.cumsum(slice_lens[:-1])])
        slot_mapping_np = np.stack(
            [kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1)
        slot_mapping_np = np.transpose(slot_mapping_np)
        slot_mapping = jnp.array(slot_mapping_np, dtype=jnp.int32)
        return new_kv, slot_mapping, kv_cache, num_slices

    @parameterized.product(
        page_size=[32, 33],
        combined_kv_head_num=[2, 16],
        head_dim=[128, 256],
        num_slices_per_block=[None, 8],
        dynamic_validate_inputs=[False, True],
    )
    def test_basic(self, page_size: int, combined_kv_head_num: int,
                   head_dim: int, num_slices_per_block: int,
                   dynamic_validate_inputs: bool):
        new_kv, slot_mapping, kv_cache, num_slices = self._generate_data(
            page_size, combined_kv_head_num, head_dim)
        old_kv_cache_copy = kv_cache.copy()

        with jax.disable_jit(disable=dynamic_validate_inputs):
            updated_kv_cache = kv_cache_update(
                new_kv,
                slot_mapping,
                kv_cache,
                num_slices,
                page_size=page_size,
                num_slices_per_block=num_slices_per_block,
                dynamic_validate_inputs=dynamic_validate_inputs)
        updated_kv_cache_ref = kv_cache_update_ref(new_kv,
                                                   np.asarray(slot_mapping),
                                                   old_kv_cache_copy)
        self.assertAllClose(updated_kv_cache,
                            updated_kv_cache_ref,
                            atol=1e-4,
                            rtol=1e-4)

    @parameterized.product(
        page_size=[32, 33],
        combined_kv_head_num=[16, 32],
        head_dim=[128, 256],
        num_slices_per_block=[None, 8],
    )
    def test_torchax_shard_map(self, page_size: int, combined_kv_head_num: int,
                               head_dim: int, num_slices_per_block: int):
        new_kv, slot_mapping, kv_cache, num_slices = self._generate_data(
            page_size, combined_kv_head_num, head_dim)
        old_kv_cache_copy = kv_cache.copy()

        mesh = Mesh(jax.devices(), 'x')
        kv_cache_pspec = P(None, 'x', None)

        new_kv = jax.device_put(new_kv, NamedSharding(mesh, kv_cache_pspec))
        slot_mapping = jax.device_put(slot_mapping, NamedSharding(mesh, P()))
        kv_cache = jax.device_put(kv_cache,
                                  NamedSharding(mesh, kv_cache_pspec))
        num_slices = jax.device_put(num_slices, NamedSharding(mesh, P()))

        updated_kv_cache = kv_cache_update(
            new_kv,
            slot_mapping,
            kv_cache,
            num_slices,
            page_size=page_size,
            num_slices_per_block=num_slices_per_block,
            mesh=mesh,
            kv_cache_pspec=kv_cache_pspec)
        updated_kv_cache_ref = kv_cache_update_ref(new_kv,
                                                   np.asarray(slot_mapping),
                                                   old_kv_cache_copy)
        self.assertAllClose(updated_kv_cache,
                            updated_kv_cache_ref,
                            atol=1e-4,
                            rtol=1e-4)

    def test_invalid_inputs(self):
        # Test all the cases when the inputs are invalid in the `_dynamic_validate_inputs` method
        page_size = 32
        combined_kv_head_num = 2
        head_dim = 128

        new_kv, slot_mapping, kv_cache, num_slices = self._generate_data(
            page_size, combined_kv_head_num, head_dim)

        with jax.disable_jit():
            # Case 1: new_kv_start < 0
            invalid_slot_mapping = slot_mapping.at[1, 0].set(-1)
            with self.assertRaisesRegex(
                    ValueError, "new_kv_start=-1 must be greater than"):
                kv_cache_update(new_kv,
                                invalid_slot_mapping,
                                kv_cache,
                                num_slices,
                                page_size=page_size,
                                dynamic_validate_inputs=True)

            # Case 2: kv_cache_start < 0
            invalid_slot_mapping = slot_mapping.at[0, 0].set(-1)
            with self.assertRaisesRegex(
                    ValueError, "kv_cache_start=-1 must be greater than"):
                kv_cache_update(new_kv,
                                invalid_slot_mapping,
                                kv_cache,
                                num_slices,
                                page_size=page_size,
                                dynamic_validate_inputs=True)

            # Case 3: slice_len <= 0
            invalid_slot_mapping = slot_mapping.at[2, 0].set(0)
            with self.assertRaisesRegex(
                    ValueError, "slice_len=0 must be less or equal to"):
                kv_cache_update(new_kv,
                                invalid_slot_mapping,
                                kv_cache,
                                num_slices,
                                page_size=page_size,
                                dynamic_validate_inputs=True)

            # Case 4: slice_len > page_size
            invalid_slot_mapping = slot_mapping.at[2, 0].set(page_size + 1)
            with self.assertRaisesRegex(
                    ValueError,
                    f"slice_len={page_size + 1} must be less or equal to"):
                kv_cache_update(new_kv,
                                invalid_slot_mapping,
                                kv_cache,
                                num_slices,
                                page_size=page_size,
                                dynamic_validate_inputs=True)

            # Case 5: new_kv_start + slice_len > new_token_num
            invalid_slot_mapping = slot_mapping.at[1, 0].set(new_kv.shape[0])
            with self.assertRaisesRegex(
                    ValueError,
                    "new_kv_start=128 \+ slice_len=7 must be less or equal to new_token_num=128"
            ):
                kv_cache_update(new_kv,
                                invalid_slot_mapping,
                                kv_cache,
                                num_slices,
                                page_size=page_size,
                                dynamic_validate_inputs=True)

            # Case 6: kv_cache_start + slice_len > kv_cache_token_num
            invalid_slot_mapping = slot_mapping.at[0, 0].set(kv_cache.shape[0])
            with self.assertRaisesRegex(
                    ValueError,
                    "kv_cache_start=640 \+ slice_len=7 must be less or equal to kv_cache_token_num=640"
            ):
                kv_cache_update(new_kv,
                                invalid_slot_mapping,
                                kv_cache,
                                num_slices,
                                page_size=page_size,
                                dynamic_validate_inputs=True)

            # Case 7: Each slice must reside in the same page
            invalid_slot_mapping = slot_mapping.at[0, 0].set(page_size - 1)
            invalid_slot_mapping = invalid_slot_mapping.at[2, 0].set(page_size)
            with self.assertRaisesRegex(
                    ValueError, "Each slice must reside in the same page"):
                kv_cache_update(new_kv,
                                invalid_slot_mapping,
                                kv_cache,
                                num_slices,
                                page_size=page_size,
                                dynamic_validate_inputs=True)

            # Case 8: new_kv slices are not continuous
            invalid_slot_mapping = slot_mapping.at[1, 1].set(
                slot_mapping[1, 1] + 1)
            with self.assertRaisesRegex(ValueError, "is expeced to equal to"):
                kv_cache_update(new_kv,
                                invalid_slot_mapping,
                                kv_cache,
                                num_slices,
                                page_size=page_size,
                                dynamic_validate_inputs=True)

            # Case 9: Overlap among the kv cache slices
            invalid_slot_mapping = slot_mapping.at[0, 4].set(slot_mapping[0, 3])
            with self.assertRaisesRegex(
                    ValueError, "Overlap detected in kv_cache intervals"):
                kv_cache_update(new_kv,
                                invalid_slot_mapping,
                                kv_cache,
                                num_slices,
                                page_size=page_size,
                                dynamic_validate_inputs=True)
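
The tests above fix the slot-mapping convention for kv_cache_update: a (3, num_slices) int32 array whose columns are (kv_cache_start, new_kv_start, slice_len), with each slice confined to a single page and the new_kv slices laid out contiguously. A small self-contained NumPy sketch of building such a mapping and applying the same reference semantics; the helper names below are illustrative, not part of the package.

import numpy as np

def build_slot_mapping(kv_cache_starts, slice_lens):
    """Columns are (kv_cache_start, new_kv_start, slice_len); new_kv slices are contiguous."""
    slice_lens = np.asarray(slice_lens, dtype=np.int32)
    new_kv_starts = np.concatenate(
        [np.zeros(1, dtype=np.int32), np.cumsum(slice_lens[:-1])])
    return np.stack([np.asarray(kv_cache_starts, dtype=np.int32),
                     new_kv_starts, slice_lens])

def kv_cache_update_ref(new_kv, slot_mapping, kv_cache):
    # Same reference semantics as in the test above, expressed with NumPy.
    for kv_start, new_start, length in slot_mapping.T:
        kv_cache[kv_start:kv_start + length] = new_kv[new_start:new_start + length]
    return kv_cache

page_size = 4
kv_cache = np.zeros((5 * page_size, 2, 8), dtype=np.float32)  # 5 pages of 4 tokens
new_kv = np.ones((6, 2, 8), dtype=np.float32)                 # 6 new tokens
# Write 2 tokens at the start of page 1 and 4 tokens at the start of page 3.
slot_mapping = build_slot_mapping(kv_cache_starts=[page_size * 1, page_size * 3],
                                  slice_lens=[2, 4])
kv_cache = kv_cache_update_ref(new_kv, slot_mapping, kv_cache)
assert kv_cache[page_size:page_size + 2].sum() == 2 * 2 * 8
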