tpu_inference-0.11.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (168)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_adapters.py +83 -0
  4. tests/core/test_core_tpu.py +523 -0
  5. tests/core/test_disagg_executor.py +60 -0
  6. tests/core/test_disagg_utils.py +53 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  10. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  11. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  12. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  13. tests/lora/__init__.py +0 -0
  14. tests/lora/test_lora.py +123 -0
  15. tests/test_base.py +201 -0
  16. tests/test_quantization.py +836 -0
  17. tests/test_tpu_info.py +120 -0
  18. tests/test_utils.py +218 -0
  19. tests/tpu_backend_test.py +59 -0
  20. tpu_inference/__init__.py +30 -0
  21. tpu_inference/adapters/__init__.py +0 -0
  22. tpu_inference/adapters/vllm_adapters.py +42 -0
  23. tpu_inference/adapters/vllm_config_adapters.py +134 -0
  24. tpu_inference/backend.py +69 -0
  25. tpu_inference/core/__init__.py +0 -0
  26. tpu_inference/core/adapters.py +153 -0
  27. tpu_inference/core/core_tpu.py +776 -0
  28. tpu_inference/core/disagg_executor.py +117 -0
  29. tpu_inference/core/disagg_utils.py +51 -0
  30. tpu_inference/di/__init__.py +0 -0
  31. tpu_inference/di/abstracts.py +28 -0
  32. tpu_inference/di/host.py +76 -0
  33. tpu_inference/di/interfaces.py +51 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/tpu_connector.py +699 -0
  36. tpu_inference/distributed/utils.py +59 -0
  37. tpu_inference/executors/__init__.py +0 -0
  38. tpu_inference/executors/ray_distributed_executor.py +346 -0
  39. tpu_inference/experimental/__init__.py +0 -0
  40. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  41. tpu_inference/interfaces/__init__.py +0 -0
  42. tpu_inference/interfaces/cache.py +31 -0
  43. tpu_inference/interfaces/config.py +47 -0
  44. tpu_inference/interfaces/config_parts.py +117 -0
  45. tpu_inference/interfaces/engine.py +51 -0
  46. tpu_inference/interfaces/outputs.py +22 -0
  47. tpu_inference/interfaces/params.py +21 -0
  48. tpu_inference/interfaces/platform.py +74 -0
  49. tpu_inference/interfaces/request.py +39 -0
  50. tpu_inference/interfaces/scheduler.py +31 -0
  51. tpu_inference/kernels/__init__.py +0 -0
  52. tpu_inference/kernels/collectives/__init__.py +0 -0
  53. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  54. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  55. tpu_inference/kernels/collectives/util.py +47 -0
  56. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  57. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  58. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  59. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  60. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  61. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  62. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  66. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
  71. tpu_inference/layers/__init__.py +0 -0
  72. tpu_inference/layers/common/__init__.py +0 -0
  73. tpu_inference/layers/common/attention_metadata.py +34 -0
  74. tpu_inference/layers/jax/__init__.py +0 -0
  75. tpu_inference/layers/jax/attention/__init__.py +0 -0
  76. tpu_inference/layers/jax/attention/attention.py +254 -0
  77. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  78. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  79. tpu_inference/layers/jax/attention_interface.py +356 -0
  80. tpu_inference/layers/jax/base.py +151 -0
  81. tpu_inference/layers/jax/binary_search.py +295 -0
  82. tpu_inference/layers/jax/constants.py +88 -0
  83. tpu_inference/layers/jax/layers.py +301 -0
  84. tpu_inference/layers/jax/misc.py +16 -0
  85. tpu_inference/layers/jax/moe/__init__.py +0 -0
  86. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  87. tpu_inference/layers/jax/moe/moe.py +209 -0
  88. tpu_inference/layers/jax/rope.py +172 -0
  89. tpu_inference/layers/jax/rope_interface.py +214 -0
  90. tpu_inference/layers/jax/sample/__init__.py +0 -0
  91. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  92. tpu_inference/layers/jax/sample/sampling.py +95 -0
  93. tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
  94. tpu_inference/layers/jax/sharding.py +406 -0
  95. tpu_inference/layers/jax/transformer_block.py +76 -0
  96. tpu_inference/layers/vllm/__init__.py +0 -0
  97. tpu_inference/layers/vllm/attention.py +184 -0
  98. tpu_inference/layers/vllm/fused_moe.py +399 -0
  99. tpu_inference/layers/vllm/linear_common.py +186 -0
  100. tpu_inference/layers/vllm/quantization/__init__.py +34 -0
  101. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  102. tpu_inference/layers/vllm/quantization/common.py +105 -0
  103. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  104. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
  105. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  106. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  108. tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
  109. tpu_inference/layers/vllm/sharding.py +151 -0
  110. tpu_inference/logger.py +10 -0
  111. tpu_inference/lora/__init__.py +0 -0
  112. tpu_inference/lora/torch_lora_ops.py +103 -0
  113. tpu_inference/lora/torch_punica_tpu.py +308 -0
  114. tpu_inference/mock/__init__.py +0 -0
  115. tpu_inference/mock/vllm_config_utils.py +28 -0
  116. tpu_inference/mock/vllm_envs.py +1233 -0
  117. tpu_inference/mock/vllm_logger.py +212 -0
  118. tpu_inference/mock/vllm_logging_utils.py +15 -0
  119. tpu_inference/models/__init__.py +0 -0
  120. tpu_inference/models/common/__init__.py +0 -0
  121. tpu_inference/models/common/model_loader.py +433 -0
  122. tpu_inference/models/jax/__init__.py +0 -0
  123. tpu_inference/models/jax/deepseek_v3.py +868 -0
  124. tpu_inference/models/jax/llama3.py +366 -0
  125. tpu_inference/models/jax/llama4.py +473 -0
  126. tpu_inference/models/jax/llama_eagle3.py +333 -0
  127. tpu_inference/models/jax/phi3.py +376 -0
  128. tpu_inference/models/jax/qwen2.py +375 -0
  129. tpu_inference/models/jax/qwen2_5_vl.py +976 -0
  130. tpu_inference/models/jax/qwen3.py +302 -0
  131. tpu_inference/models/jax/utils/__init__.py +0 -0
  132. tpu_inference/models/jax/utils/file_utils.py +96 -0
  133. tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
  134. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
  136. tpu_inference/models/jax/utils/weight_utils.py +510 -0
  137. tpu_inference/models/vllm/__init__.py +0 -0
  138. tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
  139. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  140. tpu_inference/platforms/__init__.py +2 -0
  141. tpu_inference/platforms/tpu_jax.py +257 -0
  142. tpu_inference/runner/__init__.py +0 -0
  143. tpu_inference/runner/block_table_jax.py +122 -0
  144. tpu_inference/runner/compilation_manager.py +672 -0
  145. tpu_inference/runner/input_batch_jax.py +435 -0
  146. tpu_inference/runner/kv_cache.py +119 -0
  147. tpu_inference/runner/kv_cache_manager.py +460 -0
  148. tpu_inference/runner/lora_utils.py +92 -0
  149. tpu_inference/runner/multimodal_manager.py +208 -0
  150. tpu_inference/runner/persistent_batch_manager.py +244 -0
  151. tpu_inference/runner/speculative_decoding_manager.py +250 -0
  152. tpu_inference/runner/structured_decoding_manager.py +89 -0
  153. tpu_inference/runner/tpu_jax_runner.py +771 -0
  154. tpu_inference/runner/utils.py +426 -0
  155. tpu_inference/spec_decode/__init__.py +0 -0
  156. tpu_inference/spec_decode/jax/__init__.py +0 -0
  157. tpu_inference/spec_decode/jax/eagle3.py +334 -0
  158. tpu_inference/tpu_info.py +77 -0
  159. tpu_inference/utils.py +294 -0
  160. tpu_inference/worker/__init__.py +0 -0
  161. tpu_inference/worker/_temporary_vllm_compat.py +129 -0
  162. tpu_inference/worker/base.py +100 -0
  163. tpu_inference/worker/tpu_worker_jax.py +321 -0
  164. tpu_inference-0.11.1.dist-info/METADATA +101 -0
  165. tpu_inference-0.11.1.dist-info/RECORD +168 -0
  166. tpu_inference-0.11.1.dist-info/WHEEL +5 -0
  167. tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
  168. tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
tests/core/test_disagg_utils.py
@@ -0,0 +1,53 @@
+import unittest
+
+from tpu_inference.core.disagg_utils import _parse_slices
+
+
+class DisaggUtilsTest(unittest.TestCase):
+
+    def test_parse_slices_valid_cases(self):
+        """Tests valid slice strings."""
+        # Test with a single slice
+        self.assertEqual(_parse_slices("2x2"), ((2, 2), ))
+        self.assertEqual(_parse_slices("2"), (2, ))
+
+        # Test with multiple slices
+        self.assertEqual(_parse_slices("2x2,2x1,3,2x4"),
+                         ((2, 2), (2, 1), 3, (2, 4)))
+
+        # Test with various dimensions
+        self.assertEqual(_parse_slices("1x1,10x10,5x3"),
+                         ((1, 1), (10, 10), (5, 3)))
+
+        # Test with an empty string
+        self.assertEqual(_parse_slices(""), ())
+
+    def test_parse_slices_with_whitespace(self):
+        """Tests valid slice strings with extra whitespace."""
+        self.assertEqual(_parse_slices(" 2x2 "), ((2, 2), ))
+        self.assertEqual(_parse_slices(" 2x2 , 2x1 , 2x4 "),
+                         ((2, 2), (2, 1), (2, 4)))
+        # The current implementation allows spaces inside the slice definition
+        self.assertEqual(_parse_slices("2 x 2"), ((2, 2), ))
+        self.assertEqual(_parse_slices(" 10 x 10 "), ((10, 10), ))
+
+    def test_parse_slices_invalid_cases(self):
+        """Tests malformed slice strings that should raise ValueError."""
+        invalid_strings = [
+            "2*2",  # wrong separator
+            "2x",  # incomplete
+            "axb",  # not integers
+            "2x2x2",  # too many dimensions
+            "2x2,3*3",  # partially malformed
+            ",2x2",  # leading comma
+            "2x2,",  # trailing comma
+            "2x2,,2x1",  # empty slice in middle
+        ]
+        for invalid_str in invalid_strings:
+            with self.subTest(invalid_str=invalid_str):
+                with self.assertRaises(ValueError):
+                    _parse_slices(invalid_str)
+
+
+if __name__ == '__main__':
+    unittest.main()
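
The tests above fully pin down the `_parse_slices` contract: comma-separated slices, each either a bare integer or an `NxM` pair, with whitespace tolerated anywhere and everything else rejected. A minimal sketch of an implementation that satisfies them (the name `_parse_slices_sketch` is hypothetical; the shipped implementation is in tpu_inference/core/disagg_utils.py):

from typing import Tuple, Union

def _parse_slices_sketch(spec: str) -> Tuple[Union[int, Tuple[int, ...]], ...]:
    # "2x2,2x1,3" -> ((2, 2), (2, 1), 3); "" -> (); whitespace is tolerated
    # anywhere, including inside a slice ("2 x 2").
    spec = spec.strip()
    if not spec:
        return ()
    slices = []
    for part in spec.split(","):
        dims = [d.strip() for d in part.split("x")]
        # At most two dimensions, all of them integers; anything else
        # (wrong separator, empty slice, "2x2x2") raises ValueError.
        if not 1 <= len(dims) <= 2 or not all(d.isdigit() for d in dims):
            raise ValueError(f"Malformed slice spec: {part!r}")
        slices.append(
            int(dims[0]) if len(dims) == 1 else tuple(int(d) for d in dims))
    return tuple(slices)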
tests/core/test_init.py
@@ -0,0 +1,49 @@
+import importlib
+import unittest
+from unittest.mock import patch
+
+
+class TestPathwaysInit(unittest.TestCase):
+
+    @patch.dict("os.environ", {"JAX_PLATFORMS": "proxy,cpu"})
+    def test_VLLM_TPU_USING_PATHWAYS_enabled(self):
+        """Test when JAX_PLATFORMS contains 'proxy'."""
+
+        # Import vllm.envs to test the VLLM_TPU_USING_PATHWAYS logic
+        import vllm.envs as envs
+
+        # Reload the module to ensure fresh import
+        importlib.reload(envs)
+
+        # Check that VLLM_TPU_USING_PATHWAYS is True when JAX_PLATFORMS contains "proxy"
+        self.assertTrue(envs.VLLM_TPU_USING_PATHWAYS)
+
+    @patch.dict("os.environ", {"JAX_PLATFORMS": "cpu"})
+    def test_VLLM_TPU_USING_PATHWAYS_not_enabled(self):
+        """Test when JAX_PLATFORMS does not contain 'proxy'."""
+
+        # Import vllm.envs to test the VLLM_TPU_USING_PATHWAYS logic
+        import vllm.envs as envs
+
+        # Reload the module to ensure fresh import
+        importlib.reload(envs)
+
+        # Check that VLLM_TPU_USING_PATHWAYS is False when JAX_PLATFORMS doesn't contain "proxy"
+        self.assertFalse(envs.VLLM_TPU_USING_PATHWAYS)
+
+    @patch.dict("os.environ", {"JAX_PLATFORMS": "PROXY,CPU"})
+    def test_VLLM_TPU_USING_PATHWAYS_case_insensitive(self):
+        """Test that the JAX_PLATFORMS check is case insensitive."""
+
+        # Import vllm.envs to test the VLLM_TPU_USING_PATHWAYS logic
+        import vllm.envs as envs
+
+        # Reload the module to ensure fresh import
+        importlib.reload(envs)
+
+        # Check that VLLM_TPU_USING_PATHWAYS is True even with uppercase "PROXY"
+        self.assertTrue(envs.VLLM_TPU_USING_PATHWAYS)
+
+
+if __name__ == "__main__":
+    unittest.main()
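
The three cases above reduce to a case-insensitive substring check on JAX_PLATFORMS. A sketch of the equivalent logic, assuming only what the tests assert (the helper name is hypothetical; the shipped definition presumably lives in the vllm.envs module this package mocks in tpu_inference/mock/vllm_envs.py):

import os

def _vllm_tpu_using_pathways() -> bool:
    # Pathways mode is on iff "proxy" appears, case-insensitively,
    # anywhere in JAX_PLATFORMS (e.g. "proxy,cpu" or "PROXY,CPU").
    return "proxy" in os.environ.get("JAX_PLATFORMS", "").lower()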
tests/kernels/__init__.py: file without changes
tests/kernels/quantized_matmul_kernel_test.py
@@ -0,0 +1,191 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import functools
+
+import jax
+import jax.numpy as jnp
+from absl.testing import absltest, parameterized
+from jax._src import test_util as jtu
+
+from tpu_inference.kernels.quantized_matmul import (kernel, tuned_block_sizes,
+                                                    util)
+
+quantized_matmul_kernel = kernel.quantized_matmul_kernel
+quantize_tensor = util.quantize_tensor
+get_tuned_block_sizes = tuned_block_sizes.get_tuned_block_sizes
+
+jax.config.parse_flags_with_absl()
+
+
+@functools.partial(jax.jit, static_argnames=["quantize_activation"])
+def reference_quantized_matmul(
+    x: jax.Array,
+    w_q: jax.Array,
+    w_scale: jax.Array,
+    quantize_activation=True,
+):
+    if quantize_activation:
+        acc_dtype = jnp.float32
+        if jnp.issubdtype(w_q.dtype, jnp.integer):
+            acc_dtype = jnp.int32
+
+        x_q, x_scale = quantize_tensor(x, w_q.dtype)
+        out = jax.lax.dot_general(
+            x_q,
+            w_q,
+            dimension_numbers=(((1, ), (1, )), ((), ())),
+            preferred_element_type=acc_dtype,
+        ).astype(jnp.float32)
+        out *= x_scale
+    else:
+        out = jax.lax.dot_general(
+            x,
+            w_q,
+            dimension_numbers=(((1, ), (1, )), ((), ())),
+            preferred_element_type=jnp.float32,
+        )
+    out *= jnp.expand_dims(w_scale, 0)
+    return out.astype(x.dtype)
+
+
+@jtu.with_config(jax_numpy_dtype_promotion="standard")
+class QuantizedMatmulKernelTest(jtu.JaxTestCase):
+
+    def setUp(self):
+        super().setUp()
+        if not jtu.is_device_tpu_at_least(6):
+            self.skipTest("Expect TPUv6+")
+
+    def _test_quantized_matmul(
+        self,
+        dtype: jnp.dtype,
+        q_dtype: jnp.dtype,
+        bs: int,
+        n_input_features: int,
+        n_output_features: int,
+        quantize_activation: bool,
+        tuned_value=None,
+        atol=0.5,
+        rtol=0.5,
+    ):
+        prng_key = jax.random.key(1234)
+        k0, k1 = jax.random.split(prng_key, 2)
+        x = jax.random.uniform(k0, (bs, n_input_features),
+                               dtype=dtype,
+                               minval=0,
+                               maxval=1)
+        w = jax.random.uniform(
+            k1,
+            (n_output_features, n_input_features),
+            dtype=dtype,
+            minval=-1,
+            maxval=1,
+        )
+        w_q, w_scale = quantize_tensor(w, q_dtype)
+        w_scale = jnp.squeeze(w_scale)
+        assert w_scale.shape == (n_output_features, )
+
+        x_q_dtype = w_q.dtype if quantize_activation else dtype
+        output = quantized_matmul_kernel(
+            x,
+            w_q,
+            w_scale,
+            x_q_dtype=x_q_dtype,
+            tuned_value=tuned_value,
+        )
+        expected = reference_quantized_matmul(
+            x, w_q, w_scale, quantize_activation=quantize_activation)
+
+        self.assertAllClose(output,
+                            expected,
+                            rtol=rtol,
+                            atol=atol,
+                            check_dtypes=True)
+
+    @parameterized.product(
+        dtype=[jnp.bfloat16, jnp.float32],
+        q_dtype=[jnp.int8, jnp.float8_e4m3fn],
+        bs=[128, 256, 512],
+        n_input_features=[128, 256, 512],
+        n_output_features=[128, 256, 512],
+        quantize_activation=[True],
+    )
+    def test_quantized_matmul_various_input_shapes(
+        self,
+        dtype: jnp.dtype,
+        q_dtype: jnp.dtype,
+        bs: int,
+        n_input_features: int,
+        n_output_features: int,
+        quantize_activation: bool,
+    ):
+        self._test_quantized_matmul(
+            dtype,
+            q_dtype,
+            bs,
+            n_input_features,
+            n_output_features,
+            quantize_activation=quantize_activation,
+            tuned_value=None,
+        )
+
+    @parameterized.product(
+        dtype=[jnp.bfloat16, jnp.float32],
+        q_dtype=[jnp.int8, jnp.float8_e4m3fn],
+        bs=[64, 192],
+        n_input_features=[64, 192],
+        n_output_features=[64, 192],
+        quantize_activation=[True],
+    )
+    def test_quantized_matmul_unaligned_input_shapes(
+        self,
+        dtype: jnp.dtype,
+        q_dtype: jnp.dtype,
+        bs: int,
+        n_input_features: int,
+        n_output_features: int,
+        quantize_activation: bool,
+    ):
+        self._test_quantized_matmul(
+            dtype,
+            q_dtype,
+            bs,
+            n_input_features,
+            n_output_features,
+            quantize_activation=quantize_activation,
+            tuned_value=None,
+        )
+
+    @parameterized.parameters(
+        (jnp.bfloat16, jnp.int8, 128, 1280, 8192, True),
+        (jnp.bfloat16, jnp.int8, 128, 28672, 4096, True),
+        (jnp.bfloat16, jnp.int8, 128, 4096, 14336, True),
+        (jnp.bfloat16, jnp.int8, 128, 4096, 4096, True),
+        (jnp.bfloat16, jnp.int8, 128, 6144, 4096, True),
+        (jnp.bfloat16, jnp.int8, 128, 7168, 8192, True),
+        (jnp.bfloat16, jnp.int8, 128, 8192, 1024, True),
+        (jnp.bfloat16, jnp.int8, 128, 8192, 3584, True),
+    )
+    def test_quantized_matmul_use_tuned_block_sizes(
+        self,
+        dtype: jnp.dtype,
+        q_dtype: jnp.dtype,
+        bs: int,
+        n_input_features: int,
+        n_output_features: int,
+        quantize_activation: bool,
+    ):
+        self._test_quantized_matmul(
+            dtype,
+            q_dtype,
+            bs,
+            n_input_features,
+            n_output_features,
+            quantize_activation=quantize_activation,
+            tuned_value=None,
+        )
+
+
+if __name__ == "__main__":
+    absltest.main(testLoader=jtu.JaxTestLoader())
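
For context outside the harness, the kernel's calling convention as exercised by `_test_quantized_matmul` looks like the sketch below. The shapes come from the tuned-shape list above, a TPU v6+ device is assumed, and `tuned_value=None` matches what every test passes (block sizes are then presumably resolved from the tuned table):

import jax
import jax.numpy as jnp

from tpu_inference.kernels.quantized_matmul import kernel, util

# Quantize a weight matrix to int8 with per-output-channel scales, then run
# the kernel with quantized activations, mirroring _test_quantized_matmul.
x = jax.random.uniform(jax.random.key(0), (128, 4096), dtype=jnp.bfloat16)
w = jax.random.uniform(jax.random.key(1), (14336, 4096),
                       dtype=jnp.bfloat16, minval=-1, maxval=1)
w_q, w_scale = util.quantize_tensor(w, jnp.int8)
out = kernel.quantized_matmul_kernel(
    x, w_q, jnp.squeeze(w_scale),
    x_q_dtype=jnp.int8,  # quantize the activations as well
    tuned_value=None)    # the tests pass None for every shape
# out has shape (128, 14336) in x.dtype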
tests/kernels/ragged_kv_cache_update_v2_test.py
@@ -0,0 +1,234 @@
+import jax
+import jax.numpy as jnp
+import numpy as np
+from absl.testing import parameterized
+from jax._src import test_util as jtu
+from jax.sharding import Mesh, NamedSharding
+from jax.sharding import PartitionSpec as P
+
+from tpu_inference.kernels.ragged_paged_attention.v2.ragged_kv_cache_update import \
+    kv_cache_update
+
+
+def kv_cache_update_ref(new_kv, slot_mapping, kv_cache):
+    """Reference implementation of KV cache update."""
+    for i in range(slot_mapping.shape[1]):
+        start_idx, new_kv_idx, slice_len = slot_mapping[:, i]
+        kv_cache = kv_cache.at[start_idx:start_idx + slice_len].set(
+            new_kv[new_kv_idx:new_kv_idx + slice_len])
+    return kv_cache
+
+
+@jtu.with_config(jax_numpy_dtype_promotion="standard")
+class KVCacheUpdateTest(jtu.JaxTestCase):
+
+    def _generate_data(self, page_size, combined_kv_head_num, head_dim):
+        page_num = 20
+        padded_num_tokens = 128
+        prng_key = jax.random.key(1234)
+        kv_cache = jnp.zeros(
+            (page_num * page_size, combined_kv_head_num, head_dim),
+            dtype=jnp.bfloat16)
+        new_kv = jax.random.normal(
+            prng_key, (padded_num_tokens, combined_kv_head_num, head_dim),
+            dtype=jnp.bfloat16)
+        slice_lens = np.array([7, page_size, page_size, 1, 1, 1, 9],
+                              dtype=np.int32)
+        num_slices = jnp.array([len(slice_lens)], dtype=np.int32)
+        kv_cache_start_indices = np.array([
+            page_size * 2 - 7, page_size * 2, page_size * 3,
+            page_size * 4 + 6, page_size * 5 + 7, page_size * 6 + 8,
+            page_size * 15 + 3
+        ], dtype=np.int32)
+        new_kv_cache_indices = np.concatenate(
+            [np.array([0], dtype=np.int32),
+             np.cumsum(slice_lens[:-1])])
+        slot_mapping_np = np.stack(
+            [kv_cache_start_indices, new_kv_cache_indices, slice_lens],
+            axis=1)
+        slot_mapping_np = np.transpose(slot_mapping_np)
+        slot_mapping = jnp.array(slot_mapping_np, dtype=jnp.int32)
+        return new_kv, slot_mapping, kv_cache, num_slices
+
+    @parameterized.product(
+        page_size=[32, 33],
+        combined_kv_head_num=[2, 16],
+        head_dim=[128, 256],
+        num_slices_per_block=[None, 8],
+        dynamic_validate_inputs=[False, True],
+    )
+    def test_basic(self, page_size: int, combined_kv_head_num: int,
+                   head_dim: int, num_slices_per_block: int,
+                   dynamic_validate_inputs: bool):
+        new_kv, slot_mapping, kv_cache, num_slices = self._generate_data(
+            page_size, combined_kv_head_num, head_dim)
+        old_kv_cache_copy = kv_cache.copy()
+
+        with jax.disable_jit(disable=dynamic_validate_inputs):
+            updated_kv_cache = kv_cache_update(
+                new_kv,
+                slot_mapping,
+                kv_cache,
+                num_slices,
+                page_size=page_size,
+                num_slices_per_block=num_slices_per_block,
+                dynamic_validate_inputs=dynamic_validate_inputs)
+        updated_kv_cache_ref = kv_cache_update_ref(new_kv,
+                                                   np.asarray(slot_mapping),
+                                                   old_kv_cache_copy)
+        self.assertAllClose(updated_kv_cache,
+                            updated_kv_cache_ref,
+                            atol=1e-4,
+                            rtol=1e-4)
+
+    @parameterized.product(
+        page_size=[32, 33],
+        combined_kv_head_num=[16, 32],
+        head_dim=[128, 256],
+        num_slices_per_block=[None, 8],
+    )
+    def test_torchax_shard_map(self, page_size: int,
+                               combined_kv_head_num: int, head_dim: int,
+                               num_slices_per_block: int):
+        new_kv, slot_mapping, kv_cache, num_slices = self._generate_data(
+            page_size, combined_kv_head_num, head_dim)
+        old_kv_cache_copy = kv_cache.copy()
+
+        mesh = Mesh(jax.devices(), 'x')
+        kv_cache_pspec = P(None, 'x', None)
+
+        new_kv = jax.device_put(new_kv, NamedSharding(mesh, kv_cache_pspec))
+        slot_mapping = jax.device_put(slot_mapping, NamedSharding(mesh, P()))
+        kv_cache = jax.device_put(kv_cache,
+                                  NamedSharding(mesh, kv_cache_pspec))
+        num_slices = jax.device_put(num_slices, NamedSharding(mesh, P()))
+
+        updated_kv_cache = kv_cache_update(
+            new_kv,
+            slot_mapping,
+            kv_cache,
+            num_slices,
+            page_size=page_size,
+            num_slices_per_block=num_slices_per_block,
+            mesh=mesh,
+            kv_cache_pspec=kv_cache_pspec)
+        updated_kv_cache_ref = kv_cache_update_ref(new_kv,
+                                                   np.asarray(slot_mapping),
+                                                   old_kv_cache_copy)
+        self.assertAllClose(updated_kv_cache,
+                            updated_kv_cache_ref,
+                            atol=1e-4,
+                            rtol=1e-4)
+
+    def test_invalid_inputs(self):
+        # Test all the cases when the inputs are invalid in the
+        # `_dynamic_validate_inputs` method
+        page_size = 32
+        combined_kv_head_num = 2
+        head_dim = 128
+
+        new_kv, slot_mapping, kv_cache, num_slices = self._generate_data(
+            page_size, combined_kv_head_num, head_dim)
+
+        with jax.disable_jit():
+            # Case 1: new_kv_start < 0
+            invalid_slot_mapping = slot_mapping.at[1, 0].set(-1)
+            with self.assertRaisesRegex(
+                    ValueError, "new_kv_start=-1 must be greater than"):
+                kv_cache_update(new_kv,
+                                invalid_slot_mapping,
+                                kv_cache,
+                                num_slices,
+                                page_size=page_size,
+                                dynamic_validate_inputs=True)
+
+            # Case 2: kv_cache_start < 0
+            invalid_slot_mapping = slot_mapping.at[0, 0].set(-1)
+            with self.assertRaisesRegex(
+                    ValueError, "kv_cache_start=-1 must be greater than"):
+                kv_cache_update(new_kv,
+                                invalid_slot_mapping,
+                                kv_cache,
+                                num_slices,
+                                page_size=page_size,
+                                dynamic_validate_inputs=True)
+
+            # Case 3: slice_len <= 0
+            invalid_slot_mapping = slot_mapping.at[2, 0].set(0)
+            with self.assertRaisesRegex(
+                    ValueError, "slice_len=0 must be less or equal to"):
+                kv_cache_update(new_kv,
+                                invalid_slot_mapping,
+                                kv_cache,
+                                num_slices,
+                                page_size=page_size,
+                                dynamic_validate_inputs=True)
+
+            # Case 4: slice_len > page_size
+            invalid_slot_mapping = slot_mapping.at[2, 0].set(page_size + 1)
+            with self.assertRaisesRegex(
+                    ValueError,
+                    f"slice_len={page_size + 1} must be less or equal to"):
+                kv_cache_update(new_kv,
+                                invalid_slot_mapping,
+                                kv_cache,
+                                num_slices,
+                                page_size=page_size,
+                                dynamic_validate_inputs=True)
+
+            # Case 5: new_kv_start + slice_len > new_token_num
+            invalid_slot_mapping = slot_mapping.at[1, 0].set(new_kv.shape[0])
+            with self.assertRaisesRegex(
+                    ValueError,
+                    r"new_kv_start=128 \+ slice_len=7 must be less or equal to new_token_num=128"
+            ):
+                kv_cache_update(new_kv,
+                                invalid_slot_mapping,
+                                kv_cache,
+                                num_slices,
+                                page_size=page_size,
+                                dynamic_validate_inputs=True)
+
+            # Case 6: kv_cache_start + slice_len > kv_cache_token_num
+            invalid_slot_mapping = slot_mapping.at[0, 0].set(kv_cache.shape[0])
+            with self.assertRaisesRegex(
+                    ValueError,
+                    r"kv_cache_start=640 \+ slice_len=7 must be less or equal to kv_cache_token_num=640"
+            ):
+                kv_cache_update(new_kv,
+                                invalid_slot_mapping,
+                                kv_cache,
+                                num_slices,
+                                page_size=page_size,
+                                dynamic_validate_inputs=True)
+
+            # Case 7: Each slice must reside in the same page
+            invalid_slot_mapping = slot_mapping.at[0, 0].set(page_size - 1)
+            invalid_slot_mapping = invalid_slot_mapping.at[2, 0].set(page_size)
+            with self.assertRaisesRegex(
+                    ValueError, "Each slice must reside in the same page"):
+                kv_cache_update(new_kv,
+                                invalid_slot_mapping,
+                                kv_cache,
+                                num_slices,
+                                page_size=page_size,
+                                dynamic_validate_inputs=True)
+
+            # Case 8: new_kv slices are not continuous
+            invalid_slot_mapping = slot_mapping.at[1, 1].set(
+                slot_mapping[1, 1] + 1)
+            with self.assertRaisesRegex(ValueError, "is expeced to equal to"):
+                kv_cache_update(new_kv,
+                                invalid_slot_mapping,
+                                kv_cache,
+                                num_slices,
+                                page_size=page_size,
+                                dynamic_validate_inputs=True)
+
+            # Case 9: Overlap among the kv cache slices
+            invalid_slot_mapping = slot_mapping.at[0, 4].set(
+                slot_mapping[0, 3])
+            with self.assertRaisesRegex(
+                    ValueError, "Overlap detected in kv_cache intervals"):
+                kv_cache_update(new_kv,
+                                invalid_slot_mapping,
+                                kv_cache,
+                                num_slices,
+                                page_size=page_size,
+                                dynamic_validate_inputs=True)
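
The slot_mapping convention these tests encode is worth spelling out: a (3, num_slices) int32 array whose rows are kv_cache_start, new_kv_start, and slice_len, where each slice must stay inside one page and the new_kv slices must be contiguous. A small NumPy sketch of the update semantics, mirroring kv_cache_update_ref above with illustrative values:

import numpy as np

page_size = 4
kv_cache = np.zeros((3 * page_size, 1, 2), dtype=np.float32)  # 3 pages of 4 slots
new_kv = np.arange(5 * 1 * 2, dtype=np.float32).reshape(5, 1, 2)

# Rows: kv_cache_start, new_kv_start, slice_len. Each slice stays inside one
# page, and the new_kv slices are contiguous (tokens 0..2, then 3..4), as the
# validation exercised in test_invalid_inputs requires.
slot_mapping = np.array([[1 * page_size, 2 * page_size],
                         [0, 3],
                         [3, 2]], dtype=np.int32)

for i in range(slot_mapping.shape[1]):
    dst, src, n = slot_mapping[:, i]
    kv_cache[dst:dst + n] = new_kv[src:src + n]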