tpu_inference-0.11.1-py3-none-any.whl
This diff shows the contents of package versions that have been publicly released to a supported registry, as they appear in that registry, and is provided for informational purposes only.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_adapters.py +83 -0
- tests/core/test_core_tpu.py +523 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/test_lora.py +123 -0
- tests/test_base.py +201 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +218 -0
- tests/tpu_backend_test.py +59 -0
- tpu_inference/__init__.py +30 -0
- tpu_inference/adapters/__init__.py +0 -0
- tpu_inference/adapters/vllm_adapters.py +42 -0
- tpu_inference/adapters/vllm_config_adapters.py +134 -0
- tpu_inference/backend.py +69 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/adapters.py +153 -0
- tpu_inference/core/core_tpu.py +776 -0
- tpu_inference/core/disagg_executor.py +117 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/di/__init__.py +0 -0
- tpu_inference/di/abstracts.py +28 -0
- tpu_inference/di/host.py +76 -0
- tpu_inference/di/interfaces.py +51 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/tpu_connector.py +699 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +346 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/interfaces/__init__.py +0 -0
- tpu_inference/interfaces/cache.py +31 -0
- tpu_inference/interfaces/config.py +47 -0
- tpu_inference/interfaces/config_parts.py +117 -0
- tpu_inference/interfaces/engine.py +51 -0
- tpu_inference/interfaces/outputs.py +22 -0
- tpu_inference/interfaces/params.py +21 -0
- tpu_inference/interfaces/platform.py +74 -0
- tpu_inference/interfaces/request.py +39 -0
- tpu_inference/interfaces/scheduler.py +31 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +254 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/attention_interface.py +356 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/binary_search.py +295 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +172 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +95 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
- tpu_inference/layers/jax/sharding.py +406 -0
- tpu_inference/layers/jax/transformer_block.py +76 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +184 -0
- tpu_inference/layers/vllm/fused_moe.py +399 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +34 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
- tpu_inference/layers/vllm/sharding.py +151 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +308 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1233 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +433 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/llama3.py +366 -0
- tpu_inference/models/jax/llama4.py +473 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +976 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
- tpu_inference/models/jax/utils/weight_utils.py +510 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_jax.py +257 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table_jax.py +122 -0
- tpu_inference/runner/compilation_manager.py +672 -0
- tpu_inference/runner/input_batch_jax.py +435 -0
- tpu_inference/runner/kv_cache.py +119 -0
- tpu_inference/runner/kv_cache_manager.py +460 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +208 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +250 -0
- tpu_inference/runner/structured_decoding_manager.py +89 -0
- tpu_inference/runner/tpu_jax_runner.py +771 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +334 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +294 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/_temporary_vllm_compat.py +129 -0
- tpu_inference/worker/base.py +100 -0
- tpu_inference/worker/tpu_worker_jax.py +321 -0
- tpu_inference-0.11.1.dist-info/METADATA +101 -0
- tpu_inference-0.11.1.dist-info/RECORD +168 -0
- tpu_inference-0.11.1.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py (new file, +287 lines)
@@ -0,0 +1,287 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import functools
+
+import jax
+from jax._src import dtypes
+from jax.experimental import pallas as pl
+from jax.experimental.pallas import tpu as pltpu
+from jax.sharding import Mesh
+from jax.sharding import PartitionSpec as P
+
+from tpu_inference.utils import TPU_HEAD_SIZE_ALIGNMENT, get_dtype_packing
+
+
+def _ceil_div(a, b):
+    assert b != 0
+    return (a + b - 1) // b
+
+
+def _kv_cache_update_kernel(
+    # Prefetch
+    slices_ref,  # [3, padded_num_slices], list of (kv_cache_start, new_kv_start,
+    # slice_len)
+    num_slices_ref,  # [1]
+    # Input
+    new_kv_hbm_ref,  # [num_tokens, num_combined_kv_heads, head_dim]
+    kv_cache_hbm_ref,  # [total_num_pages * page_size, num_combined_kv_heads,
+    # head_dim]
+    # Output
+    _,  # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
+    # Scratch
+    scratch,  # [num_slices_per_block, page_size, num_combined_kv_heads,
+    # head_dim]
+    sem,
+):
+    async_copies = []
+    block_idx = pl.program_id(0)
+    num_slices_per_block = scratch.shape[0]
+
+    # Copy from new_kv_hbm_ref to scratch
+    for i in range(num_slices_per_block):
+        offset_i = i + block_idx * num_slices_per_block
+        new_kv_start = jax.lax.select(offset_i < num_slices_ref[0],
+                                      slices_ref[1, offset_i], 0)
+        length = jax.lax.select(offset_i < num_slices_ref[0],
+                                slices_ref[2, offset_i], 0)
+        async_copy = pltpu.make_async_copy(
+            new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...],
+            scratch.at[i, pl.ds(0, length), ...],
+            sem,
+        )
+        async_copy.start()
+        async_copies.append(async_copy)
+
+    for async_copy in async_copies:
+        async_copy.wait()
+
+    # Copy from scratch to kv_cache_hbm_ref
+    async_copies.clear()
+    for i in range(num_slices_per_block):
+        offset_i = i + block_idx * num_slices_per_block
+        kv_cache_start = jax.lax.select(offset_i < num_slices_ref[0],
+                                        slices_ref[0, offset_i], 0)
+        length = jax.lax.select(offset_i < num_slices_ref[0],
+                                slices_ref[2, offset_i], 0)
+        async_copy = pltpu.make_async_copy(
+            scratch.at[i, pl.ds(0, length), ...],
+            kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...],
+            sem,
+        )
+        async_copy.start()
+        async_copies.append(async_copy)
+    for async_copy in async_copies:
+        async_copy.wait()
+
+
+def _dynamic_validate_inputs(slices, new_token_num, kv_cache_token_num,
+                             page_size, num_slices):
+    slices = slices.tolist()
+    # NOTE: The padding part is unnecessary to check because kv_cache_start, new_kv_start,
+    # slice_len will be set to 0 in the kernel implementation.
+    for i in range(num_slices[0]):
+        kv_cache_start = slices[0][i]
+        new_kv_start = slices[1][i]
+        slice_len = slices[2][i]
+        if new_kv_start < 0:
+            raise ValueError(
+                f"{new_kv_start=} must be greater than or equal to 0")
+        if kv_cache_start < 0:
+            raise ValueError(
+                f"{kv_cache_start=} must be greater than or equal to 0")
+        if not 0 < slice_len <= page_size:
+            raise ValueError(
+                f"{slice_len=} must be less or equal to {page_size=} and greater than 0"
+            )
+        if new_kv_start + slice_len > new_token_num:
+            raise ValueError(
+                f"{new_kv_start=} + {slice_len=} must be less or equal to {new_token_num=}"
+            )
+        if kv_cache_start + slice_len > kv_cache_token_num:
+            raise ValueError(
+                f"{kv_cache_start=} + {slice_len=} must be less or equal to {kv_cache_token_num=}"
+            )
+        if kv_cache_start // page_size != (kv_cache_start + slice_len -
+                                           1) // page_size:
+            raise ValueError(
+                f"Each slice must reside in the same page, but got {kv_cache_start=} and {slice_len=}"
+            )
+
+    new_kv_intervals = []
+    kv_cache_intervals = []
+    for i in range(num_slices[0]):
+        new_kv_intervals.append((slices[1][i], slices[1][i] + slices[2][i]))
+        kv_cache_intervals.append((slices[0][i], slices[0][i] + slices[2][i]))
+
+    new_kv_intervals.sort()
+    kv_cache_intervals.sort()
+
+    # The new_kv slices should be continuous
+    for i in range(len(new_kv_intervals) - 1):
+        if new_kv_intervals[i][1] != new_kv_intervals[i + 1][0]:
+            raise ValueError(
+                f"{new_kv_intervals[i][1]=} is expected to equal {new_kv_intervals[i + 1][0]}"
+            )
+
+    # There should be no overlap among the kv cache slices
+    for i in range(len(kv_cache_intervals) - 1):
+        if kv_cache_intervals[i][1] > kv_cache_intervals[i + 1][0]:
+            raise ValueError(
+                f"Overlap detected in kv_cache intervals: {kv_cache_intervals[i]} and {kv_cache_intervals[i+1]}"
+            )
+
+
+def _kv_cache_update(
+    new_kv: jax.Array,  # [total_num_token, num_combined_kv_heads, head_dim]
+    slices: jax.Array,  # [3, slices], list of (kv_cache_start, new_kv_start,
+    # slice_len)
+    kv_cache: jax.
+    Array,  # [total_num_pages * page_size, num_combined_kv_heads,
+    # head_dim]
+    num_slices: jax.Array,  # [1]
+    page_size: int,
+    num_slices_per_block: int,
+    dynamic_validate_inputs: bool,
+    vmem_limit_bytes: int = 40 * 1024 * 1024,
+):
+    new_token_num, num_combined_kv_heads, head_dim = new_kv.shape
+    assert kv_cache.shape[1] == num_combined_kv_heads
+    assert kv_cache.shape[2] == head_dim
+    assert head_dim % 128 == 0
+    if dynamic_validate_inputs is True:
+        _dynamic_validate_inputs(slices, new_token_num, kv_cache.shape[0],
+                                 page_size, num_slices)
+
+    in_specs = [
+        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
+        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
+    ]
+
+    out_specs = [pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)]
+    out_shape = [jax.ShapeDtypeStruct(kv_cache.shape, dtype=kv_cache.dtype)]
+
+    scalar_prefetches = [slices, num_slices]
+    scratch = pltpu.VMEM(
+        (num_slices_per_block, page_size, num_combined_kv_heads, head_dim),
+        new_kv.dtype,
+    )
+
+    scratch_shapes = [
+        scratch,
+        pltpu.SemaphoreType.DMA,
+    ]
+
+    kernel = pl.pallas_call(
+        _kv_cache_update_kernel,
+        grid_spec=pltpu.PrefetchScalarGridSpec(
+            num_scalar_prefetch=len(scalar_prefetches),
+            in_specs=in_specs,
+            out_specs=out_specs,
+            grid=(_ceil_div(num_slices[0], num_slices_per_block), ),
+            scratch_shapes=scratch_shapes,
+        ),
+        out_shape=out_shape,
+        input_output_aliases={len(scalar_prefetches) + 1: 0},
+        compiler_params=pltpu.CompilerParams(
+            vmem_limit_bytes=vmem_limit_bytes, ),
+    )
+
+    return kernel(*scalar_prefetches, new_kv, kv_cache)[0]
+
+
+def _prev_power_of_2(n: int) -> int:
+    """The previous power of 2 (inclusive)"""
+    if n <= 0:
+        return 0
+    return 1 << (n.bit_length() - 1)
+
+
+def _get_page_size_bytes(block_size: int, num_combined_kv_heads: int,
+                         head_size: int, kv_cache_dtype) -> int:
+    """Returns the size in bytes of one page of the KV cache."""
+    kv_cache_dtype_bit_size = dtypes.bit_width(kv_cache_dtype)
+    padded_head_size = _ceil_div(
+        head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
+
+    # NOTE: for the implicit padding in XLA
+    packing = get_dtype_packing(kv_cache_dtype)
+    num_combined_kv_heads = _ceil_div(num_combined_kv_heads, packing) * packing
+
+    return block_size * num_combined_kv_heads * padded_head_size * kv_cache_dtype_bit_size // 8
+
+
+def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int,
+                                              vmem_limit_bytes: int) -> int:
+    """Find the optimum number of slices to copy per Pallas program instance.
+    Increasing the number of slices copied in one instance of the kernel program
+    will increase HBM bandwidth utilization via more in-flight DMAs.
+    However, it will also use more VMEM, and experimentally, we observed
+    performance regression at 128 slices on v6e, likely due to running
+    out of scalar registers. Thus this function will limit the number of
+    slices to 64.
+    """
+    # NOTE: We assume 1MB vmem is used for register spill and others
+    assert vmem_limit_bytes >= 1024 * 1024, "vmem_limit_bytes must be at least 1MB"
+    num_slices_per_block = (vmem_limit_bytes - 1024 * 1024) // page_size_bytes
+    assert num_slices_per_block > 0, "Number of slices should be positive"
+    num_slices_per_block = _prev_power_of_2(num_slices_per_block)
+    return min(num_slices_per_block, 64)
+
+
+@functools.partial(
+    jax.jit,
+    static_argnames=[
+        "page_size", "num_slices_per_block", "mesh", "kv_cache_pspec"
+    ],
+    donate_argnames="kv_cache",
+)
+def kv_cache_update(
+    new_kv: jax.Array,  # [total_num_token, num_combined_kv_heads, head_dim]
+    slices: jax.
+    Array,  # [3, slices], list of (kv_cache_start, new_kv_start, slice_len)
+    kv_cache: jax.
+    Array,  # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
+    num_slices: jax.Array,  # [1]
+    *,
+    page_size: int = 32,
+    num_slices_per_block: int | None = None,
+    mesh: Mesh | None = None,
+    kv_cache_pspec: P
+    | None = None,  # Only sharding along head_dim is supported
+    dynamic_validate_inputs: bool = False,
+    vmem_limit_bytes: int = 40 * 1024 * 1024,
+):
+    if num_slices_per_block is None:
+        _, num_combined_kv_heads, head_dim = new_kv.shape
+        page_size_bytes = _get_page_size_bytes(page_size,
+                                               num_combined_kv_heads, head_dim,
+                                               kv_cache.dtype)
+        num_slices_per_block = _get_num_slices_per_kv_cache_update_block(
+            page_size_bytes, vmem_limit_bytes)
+
+    if mesh is None:
+        return _kv_cache_update(new_kv, slices, kv_cache, num_slices,
+                                page_size, num_slices_per_block,
+                                dynamic_validate_inputs)
+
+    if kv_cache_pspec is None:
+        raise ValueError(
+            "kv_cache_pspec must be provided when mesh is specified")
+
+    in_specs = (kv_cache_pspec, P(), kv_cache_pspec, P())
+    out_specs = kv_cache_pspec
+    shard_map_wrapped = jax.shard_map(
+        functools.partial(
+            _kv_cache_update,
+            page_size=page_size,
+            num_slices_per_block=num_slices_per_block,
+            dynamic_validate_inputs=dynamic_validate_inputs,
+            vmem_limit_bytes=vmem_limit_bytes,
+        ),
+        mesh=mesh,
+        in_specs=in_specs,
+        out_specs=out_specs,
+        check_vma=False,
+    )
+    return shard_map_wrapped(new_kv, slices, kv_cache, num_slices)