tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +78 -1
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +38 -7
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +17 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +28 -5
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +74 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +88 -25
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -64
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +72 -37
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +45 -15
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +41 -16
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +42 -36
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +63 -50
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
- tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
|
@@ -1,3 +1,16 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
1
14
|
"""Utility functions for ragged paged attention."""
|
|
2
15
|
import jax
|
|
3
16
|
from jax._src import dtypes
|
|
@@ -13,7 +26,8 @@ def align_to(x, a):
|
|
|
13
26
|
|
|
14
27
|
|
|
15
28
|
def get_dtype_bitwidth(dtype):
|
|
16
|
-
return dtypes.bit_width(dtype)
|
|
29
|
+
return (dtypes.bit_width(dtype)
|
|
30
|
+
if hasattr(dtypes, "bit_width") else dtypes.itemsize_bits(dtype))
|
|
17
31
|
|
|
18
32
|
|
|
19
33
|
def get_dtype_packing(dtype):
|
tpu_inference/layers/__init__.py
CHANGED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -1,10 +1,23 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
import functools
|
|
2
16
|
import math
|
|
3
17
|
from typing import Any, Callable, Optional, Tuple
|
|
4
18
|
|
|
5
19
|
import jax
|
|
6
20
|
import jax.numpy as jnp
|
|
7
|
-
from jax.experimental import shard_map
|
|
8
21
|
from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention
|
|
9
22
|
from jax.experimental.pallas.ops.tpu.splash_attention import \
|
|
10
23
|
splash_attention_kernel as splash
|
|
@@ -55,11 +68,11 @@ def sharded_flash_attention(
|
|
|
55
68
|
vmem_limit_bytes=vmem_limit_bytes)
|
|
56
69
|
|
|
57
70
|
return jax.jit(
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
71
|
+
jax.shard_map(_flash_attention,
|
|
72
|
+
mesh=mesh,
|
|
73
|
+
in_specs=in_specs,
|
|
74
|
+
out_specs=out_specs,
|
|
75
|
+
check_vma=False))
|
|
63
76
|
|
|
64
77
|
|
|
65
78
|
def sharded_paged_attention(
|
|
@@ -94,12 +107,12 @@ def sharded_paged_attention(
|
|
|
94
107
|
)
|
|
95
108
|
|
|
96
109
|
return jax.jit(
|
|
97
|
-
|
|
110
|
+
jax.shard_map(
|
|
98
111
|
_paged_attention_fn,
|
|
99
112
|
mesh=mesh,
|
|
100
113
|
in_specs=in_specs,
|
|
101
114
|
out_specs=out_specs,
|
|
102
|
-
|
|
115
|
+
check_vma=False,
|
|
103
116
|
))
|
|
104
117
|
|
|
105
118
|
|
|
@@ -257,7 +270,7 @@ def sharded_splash_attention(
|
|
|
257
270
|
)
|
|
258
271
|
out_specs = P("data", "model", None, None)
|
|
259
272
|
return jax.jit(
|
|
260
|
-
|
|
273
|
+
jax.shard_map(
|
|
261
274
|
functools.partial(
|
|
262
275
|
apply_splash,
|
|
263
276
|
window_size=window_size,
|
|
@@ -267,7 +280,7 @@ def sharded_splash_attention(
|
|
|
267
280
|
mesh=mesh,
|
|
268
281
|
in_specs=in_specs,
|
|
269
282
|
out_specs=out_specs,
|
|
270
|
-
|
|
283
|
+
check_vma=False,
|
|
271
284
|
))
|
|
272
285
|
|
|
273
286
|
|
|
@@ -308,13 +321,7 @@ def sharded_ragged_paged_attention(
|
|
|
308
321
|
args = (q, k, v, kv_cache, kv_lens, page_indices, cu_q_lens, distribution)
|
|
309
322
|
|
|
310
323
|
use_hd64 = q.shape[-1] == 64
|
|
311
|
-
|
|
312
|
-
func = ragged_paged_attention
|
|
313
|
-
if use_hd64:
|
|
314
|
-
func = functools.partial(ragged_paged_attention_hd64,
|
|
315
|
-
strict_sliding_window=False)
|
|
316
|
-
else:
|
|
317
|
-
func = ragged_paged_attention
|
|
324
|
+
func = ragged_paged_attention_hd64 if use_hd64 else ragged_paged_attention
|
|
318
325
|
|
|
319
326
|
if attention_sink is not None:
|
|
320
327
|
if not use_hd64:
|
|
@@ -334,12 +341,12 @@ def sharded_ragged_paged_attention(
|
|
|
334
341
|
v_scale=v_scale,
|
|
335
342
|
)
|
|
336
343
|
|
|
337
|
-
return
|
|
344
|
+
return jax.shard_map(
|
|
338
345
|
_ragged_paged_attention,
|
|
339
346
|
mesh=mesh,
|
|
340
347
|
in_specs=in_specs,
|
|
341
348
|
out_specs=out_specs,
|
|
342
|
-
|
|
349
|
+
check_vma=False,
|
|
343
350
|
)(*args)
|
|
344
351
|
|
|
345
352
|
|
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
import functools
|
|
2
16
|
from dataclasses import dataclass, field
|
|
3
17
|
from typing import Any
|
|
@@ -1,7 +1,22 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
UNQUANTIZED = "unquantized"
|
|
2
16
|
MXFP4 = "mxfp4"
|
|
3
17
|
AWQ = "awq"
|
|
4
18
|
COMPRESSED_TENSORS = "compressed-tensors"
|
|
19
|
+
FP8 = "fp8"
|
|
5
20
|
|
|
6
21
|
|
|
7
22
|
def get_tpu_quant_method(quant_method: str) -> str:
|
|
@@ -0,0 +1,270 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import itertools
|
|
16
|
+
from typing import Tuple
|
|
17
|
+
|
|
18
|
+
import jax
|
|
19
|
+
import jax.numpy as jnp
|
|
20
|
+
|
|
21
|
+
MXFP4_BLOCK_SIZE = 32
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def quantize_tensor_to_mxfp4_packed(
|
|
25
|
+
tensor: jax.Array,
|
|
26
|
+
axis: int | tuple = -1,
|
|
27
|
+
) -> Tuple[jax.Array, jax.Array]:
|
|
28
|
+
"""Quantize a tensor to mxfp4 and pack it into uint8."""
|
|
29
|
+
|
|
30
|
+
# Perform regular block quantization.
|
|
31
|
+
tensor_q, scale = quantize_tensor(
|
|
32
|
+
jnp.float4_e2m1fn,
|
|
33
|
+
tensor,
|
|
34
|
+
axis,
|
|
35
|
+
MXFP4_BLOCK_SIZE,
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
# last two e2m1 elements will be packed into a single uint8 element.
|
|
39
|
+
bitcast_shape = tensor_q.shape[:-1] + (-1, 2)
|
|
40
|
+
tensor_q = tensor_q.reshape(bitcast_shape)
|
|
41
|
+
tensor_q_packed = jax.lax.bitcast_convert_type(tensor_q, jnp.uint8)
|
|
42
|
+
|
|
43
|
+
# Since TPU does not have native support for e8m0, we convert scale into
|
|
44
|
+
# e8m0 manually and store it as uint8.
|
|
45
|
+
e8m0_finfo = jnp.finfo(jnp.float8_e8m0fnu)
|
|
46
|
+
_, scale_exp = jnp.frexp(scale)
|
|
47
|
+
# Subtract exponents by one since e8m0 has no decimal.
|
|
48
|
+
scale_exp -= 1
|
|
49
|
+
scale_exp = (scale_exp - e8m0_finfo.minexp).astype(jnp.uint8)
|
|
50
|
+
|
|
51
|
+
return tensor_q_packed, scale_exp
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def u8_unpack_e2m1(u8_packed_e2m1: jax.Array) -> jax.Array:
|
|
55
|
+
"""Unpack e2m1 tensor packed into u8."""
|
|
56
|
+
assert u8_packed_e2m1.dtype == jnp.uint8
|
|
57
|
+
e2m1 = jax.lax.bitcast_convert_type(u8_packed_e2m1, jnp.float4_e2m1fn)
|
|
58
|
+
# bitcast creates one more dimension that splits 8 bits into two e2m1.
|
|
59
|
+
# we flatten them with the last dim.
|
|
60
|
+
return jnp.reshape(e2m1, e2m1.shape[:-2] + (-1, ))
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def e8m0_to_fp32(u8: jax.Array) -> jax.Array:
|
|
64
|
+
"""Convert e8m0 (that was bitcasted to u8) into fp32"""
|
|
65
|
+
assert u8.dtype == jnp.uint8
|
|
66
|
+
|
|
67
|
+
e8_finfo = jnp.finfo(jnp.float8_e8m0fnu)
|
|
68
|
+
exponents = u8.astype(jnp.int32) + e8_finfo.minexp
|
|
69
|
+
ones = jnp.ones_like(u8, dtype=jnp.float32)
|
|
70
|
+
return jnp.ldexp(ones, exponents)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def dequantize_tensor(
|
|
74
|
+
tensor_q: jax.Array,
|
|
75
|
+
scale: jax.Array,
|
|
76
|
+
axis: int | None | tuple = -1,
|
|
77
|
+
out_dtype: jnp.dtype = jnp.bfloat16,
|
|
78
|
+
) -> jax.Array:
|
|
79
|
+
"""Dequantize a quantized tensor
|
|
80
|
+
|
|
81
|
+
Args:
|
|
82
|
+
tensor_q: Quantized tensor.
|
|
83
|
+
scale: Quantization scale.
|
|
84
|
+
axis: The axis tensor was quantized. None denotes per-tensor.
|
|
85
|
+
out_dtype: Dtype of the output.
|
|
86
|
+
|
|
87
|
+
Returns:
|
|
88
|
+
Dequantized tensor_q.
|
|
89
|
+
"""
|
|
90
|
+
if axis is None:
|
|
91
|
+
# Perform per-tensor quantization.
|
|
92
|
+
axis = [i for i in range(tensor_q.ndim)]
|
|
93
|
+
if isinstance(axis, int):
|
|
94
|
+
axis = [axis]
|
|
95
|
+
|
|
96
|
+
orig_shape = tensor_q.shape
|
|
97
|
+
if tensor_q.ndim == scale.ndim:
|
|
98
|
+
# Indicates the tensor was block quantized.
|
|
99
|
+
blocked_shape = [[i] for i in orig_shape]
|
|
100
|
+
for i in axis:
|
|
101
|
+
num_blocks = scale.shape[i]
|
|
102
|
+
if tensor_q.shape[i] % num_blocks:
|
|
103
|
+
raise ValueError(
|
|
104
|
+
f"Unable to perform block dequantization. axis={i} of "
|
|
105
|
+
f"{tensor_q.shape=} is not divisible by {num_blocks=}", )
|
|
106
|
+
block_size = tensor_q.shape[i] // num_blocks
|
|
107
|
+
|
|
108
|
+
blocked_shape[i] = (num_blocks, block_size)
|
|
109
|
+
|
|
110
|
+
# Convert all axis into positive values.
|
|
111
|
+
axis = sorted([(i + tensor_q.ndim) % tensor_q.ndim for i in axis])
|
|
112
|
+
# Shift axis by 1 since its original position is now occupied by
|
|
113
|
+
# num_blocks dim. Also, if n axes before an axis was also quantized,
|
|
114
|
+
# shift its position by n.
|
|
115
|
+
axis = [1 + n + i for n, i in enumerate(axis)]
|
|
116
|
+
|
|
117
|
+
# Flatten list of lists that contains (num_blocks, block).
|
|
118
|
+
blocked_shape = list(itertools.chain(*blocked_shape))
|
|
119
|
+
tensor_q = tensor_q.reshape(blocked_shape)
|
|
120
|
+
|
|
121
|
+
scale = jnp.expand_dims(scale, axis)
|
|
122
|
+
|
|
123
|
+
tensor = (tensor_q.astype(jnp.float32) * scale).astype(out_dtype)
|
|
124
|
+
|
|
125
|
+
return tensor.reshape(orig_shape)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def dequantize_tensor_from_mxfp4_packed(
|
|
129
|
+
tensor_q: jax.Array,
|
|
130
|
+
scale: jax.Array,
|
|
131
|
+
axis: int | tuple = -1,
|
|
132
|
+
out_dtype: jnp.dtype = jnp.bfloat16,
|
|
133
|
+
) -> jax.Array:
|
|
134
|
+
"""Dequantize packed mxfp4 tensor.
|
|
135
|
+
|
|
136
|
+
Args:
|
|
137
|
+
tensor_q: fp4 tensor packed into uint8.
|
|
138
|
+
scale: e8m0 scale packed into uint8.
|
|
139
|
+
axis: The axis tensor was quantized.
|
|
140
|
+
out_dtype: Dtype of the output.
|
|
141
|
+
|
|
142
|
+
Returns:
|
|
143
|
+
Dequantized tensor_q.
|
|
144
|
+
"""
|
|
145
|
+
tensor_e2m1 = u8_unpack_e2m1(tensor_q)
|
|
146
|
+
scale_fp32 = e8m0_to_fp32(scale)
|
|
147
|
+
|
|
148
|
+
return dequantize_tensor(
|
|
149
|
+
tensor_e2m1,
|
|
150
|
+
scale_fp32,
|
|
151
|
+
axis,
|
|
152
|
+
out_dtype,
|
|
153
|
+
)
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def quantize_tensor(
|
|
157
|
+
dtype: jnp.dtype,
|
|
158
|
+
tensor: jax.Array,
|
|
159
|
+
axis: int | tuple | None = -1,
|
|
160
|
+
block_size: int | None = None,
|
|
161
|
+
pad_tensor: bool = False,
|
|
162
|
+
) -> tuple[jax.Array, jax.Array]:
|
|
163
|
+
"""Quantize tensor.
|
|
164
|
+
|
|
165
|
+
Args:
|
|
166
|
+
dtype: dtype to perform quantization.
|
|
167
|
+
tensor: Unquantized tensor
|
|
168
|
+
axis: Axis to perform quantization. None denotes per-tensor.
|
|
169
|
+
block_size: Specify block quantization size.
|
|
170
|
+
pad_tensor: Whether to pad the axis along block size.
|
|
171
|
+
|
|
172
|
+
Returns:
|
|
173
|
+
Tensor quantized to dtype.
|
|
174
|
+
"""
|
|
175
|
+
if axis is None:
|
|
176
|
+
# Perform per-tensor quantization.
|
|
177
|
+
axis = [i for i in range(tensor.ndim)]
|
|
178
|
+
if isinstance(axis, int):
|
|
179
|
+
axis = [axis]
|
|
180
|
+
|
|
181
|
+
orig_shape = tensor.shape
|
|
182
|
+
mask = jnp.ones_like(tensor, jnp.int32)
|
|
183
|
+
|
|
184
|
+
if block_size is not None:
|
|
185
|
+
if isinstance(block_size, int):
|
|
186
|
+
block_size = [block_size] * len(axis)
|
|
187
|
+
|
|
188
|
+
blocked_shape = [[i] for i in orig_shape]
|
|
189
|
+
pad_width = [[0, 0] for _ in range(tensor.ndim)]
|
|
190
|
+
for i, block in zip(axis, block_size):
|
|
191
|
+
num_blocks = (tensor.shape[i] + block - 1) // block
|
|
192
|
+
padding_size = num_blocks * block - tensor.shape[i]
|
|
193
|
+
if padding_size and not pad_tensor:
|
|
194
|
+
raise ValueError(
|
|
195
|
+
f"Unable to perform block quantization. axis={i} of "
|
|
196
|
+
f"{tensor.shape=} is not divisible by {block=}")
|
|
197
|
+
|
|
198
|
+
# Pad the tensor to align with block size.
|
|
199
|
+
pad_width[i][1] = padding_size
|
|
200
|
+
|
|
201
|
+
blocked_shape[i] = (num_blocks, block)
|
|
202
|
+
|
|
203
|
+
# In order to avoid padded values affecting scale value, we pad it
|
|
204
|
+
# using edge value of the tensor.
|
|
205
|
+
tensor = jnp.pad(tensor, pad_width, "edge")
|
|
206
|
+
mask = jnp.pad(mask, pad_width)
|
|
207
|
+
|
|
208
|
+
orig_shape = tensor.shape
|
|
209
|
+
# Convert all axis into positive values.
|
|
210
|
+
axis = sorted([i % tensor.ndim for i in axis])
|
|
211
|
+
# Shift axis by 1 since its original position is now occupied by
|
|
212
|
+
# num_blocks dim. Also, if n axes before an axis was also quantized,
|
|
213
|
+
# shift its position by n.
|
|
214
|
+
axis = [1 + n + i for n, i in enumerate(axis)]
|
|
215
|
+
|
|
216
|
+
# Flatten list of lists that contains (num_blocks, block).
|
|
217
|
+
blocked_shape = list(itertools.chain(*blocked_shape))
|
|
218
|
+
tensor = tensor.reshape(blocked_shape)
|
|
219
|
+
|
|
220
|
+
if jnp.issubdtype(dtype, jnp.integer):
|
|
221
|
+
dtype_info = jnp.iinfo(dtype)
|
|
222
|
+
else:
|
|
223
|
+
dtype_info = jnp.finfo(dtype)
|
|
224
|
+
|
|
225
|
+
dtype_max = float(dtype_info.max)
|
|
226
|
+
dtype_min = float(dtype_info.min)
|
|
227
|
+
|
|
228
|
+
abs_max = jnp.max(jnp.abs(tensor), axis=axis, keepdims=True)
|
|
229
|
+
scale = abs_max / dtype_max
|
|
230
|
+
|
|
231
|
+
tensor_q = jnp.clip(tensor / scale, dtype_min, dtype_max)
|
|
232
|
+
tensor_q = tensor_q.reshape(orig_shape)
|
|
233
|
+
tensor_q = tensor_q.astype(dtype)
|
|
234
|
+
|
|
235
|
+
# To avoid padded values affecting output of quantized matmul, we mask them
|
|
236
|
+
# out with 0s.
|
|
237
|
+
tensor_q = jnp.where(mask, tensor_q, 0)
|
|
238
|
+
|
|
239
|
+
scale = jnp.squeeze(scale, axis).astype(jnp.float32)
|
|
240
|
+
|
|
241
|
+
return tensor_q, scale
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def static_per_tensor_quantize_tensor(
|
|
245
|
+
dtype: jnp.dtype,
|
|
246
|
+
tensor: jax.Array,
|
|
247
|
+
scale: float,
|
|
248
|
+
) -> jax.Array:
|
|
249
|
+
if jnp.issubdtype(dtype, jnp.integer):
|
|
250
|
+
dtype_info = jnp.iinfo(dtype)
|
|
251
|
+
else:
|
|
252
|
+
dtype_info = jnp.finfo(dtype)
|
|
253
|
+
|
|
254
|
+
dtype_max = float(dtype_info.max)
|
|
255
|
+
dtype_min = float(dtype_info.min)
|
|
256
|
+
|
|
257
|
+
return jnp.clip(tensor / scale, dtype_min, dtype_max).astype(dtype)
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def quantize_kv(
    dtype: jnp.dtype,
    key: jax.Array,
    value: jax.Array,
    k_scale: float,
    v_scale: float,
) -> Tuple[jax.Array, jax.Array]:
    """Statically quantize the key and value tensors.

    Each tensor is quantized independently to ``dtype`` with its own
    precomputed per-tensor scale.

    Args:
        dtype: Target dtype for both tensors.
        key: Key tensor to quantize.
        value: Value tensor to quantize.
        k_scale: Scale used for ``key``.
        v_scale: Scale used for ``value``.

    Returns:
        A ``(key, value)`` tuple of the quantized tensors.
    """
    return (
        static_per_tensor_quantize_tensor(dtype, key, k_scale),
        static_per_tensor_quantize_tensor(dtype, value, v_scale),
    )
|
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
import json
|
|
2
16
|
import math
|
|
3
17
|
from dataclasses import asdict, dataclass
|
|
@@ -26,7 +40,7 @@ class ShardingAxisNameBase:
|
|
|
26
40
|
MLP_TENSOR = ('attn_dp', 'model', 'expert')
|
|
27
41
|
MOE_TENSOR = ('attn_dp', 'model')
|
|
28
42
|
EXPERT = ('attn_dp', 'expert', 'model')
|
|
29
|
-
VOCAB = ('expert', 'model')
|
|
43
|
+
VOCAB = ('expert', 'attn_dp', 'model')
|
|
30
44
|
|
|
31
45
|
|
|
32
46
|
class ShardingAxisName2D:
|
|
@@ -119,10 +133,19 @@ class ShardingConfigManager:
|
|
|
119
133
|
False)
|
|
120
134
|
if enable_dp_attention:
|
|
121
135
|
# Replicate attention layer when num_kv_heads < TP
|
|
122
|
-
num_kv_heads = vllm_config.model_config.get_total_num_kv_heads(
|
|
136
|
+
num_kv_heads = 1 if vllm_config.model_config.use_mla else vllm_config.model_config.get_total_num_kv_heads(
|
|
137
|
+
)
|
|
138
|
+
cache_dtype = vllm_config.cache_config.cache_dtype
|
|
139
|
+
if cache_dtype == 'auto':
|
|
140
|
+
cache_dtype = vllm_config.model_config.dtype
|
|
123
141
|
kv_dtype = utils.get_jax_dtype_from_str_dtype(
|
|
124
|
-
|
|
142
|
+
cache_dtype) or jnp.bfloat16
|
|
125
143
|
packing = 4 // jnp.dtype(kv_dtype).itemsize
|
|
144
|
+
|
|
145
|
+
# The default head dim is 128 but 64 is also supported as a special case.
|
|
146
|
+
if vllm_config.model_config.get_head_size() == 64:
|
|
147
|
+
packing *= 2
|
|
148
|
+
|
|
126
149
|
# When num_kv_heads * 2 / packing < TP, tensor parallelism would
|
|
127
150
|
# duplicate KV heads across devices, wasting kv cache memory.
|
|
128
151
|
# Use attention DP instead to reduce per-device num_kv_heads and
|
|
@@ -168,8 +191,8 @@ class ShardingConfigManager:
|
|
|
168
191
|
if sharding_strategy.attention_data_parallelism > 1:
|
|
169
192
|
if not envs.NEW_MODEL_DESIGN:
|
|
170
193
|
raise ValueError(
|
|
171
|
-
"Must run Attention DP with NEW_MODEL_DESIGN enabled. Please set
|
|
172
|
-
"NEW_MODEL_DESIGN=True
|
|
194
|
+
"Must run Attention DP with NEW_MODEL_DESIGN enabled. Please set "
|
|
195
|
+
"NEW_MODEL_DESIGN=True")
|
|
173
196
|
|
|
174
197
|
@property
|
|
175
198
|
def total_dp_size(self) -> int:
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
from dataclasses import InitVar, dataclass
|
|
2
16
|
from typing import Any, Tuple
|
|
3
17
|
|
|
@@ -5,7 +19,6 @@ import jax
|
|
|
5
19
|
import jax.numpy as jnp
|
|
6
20
|
from flax import nnx
|
|
7
21
|
from flax.typing import Sharding
|
|
8
|
-
from jax.experimental import shard_map
|
|
9
22
|
from jax.sharding import Mesh
|
|
10
23
|
from jax.sharding import PartitionSpec as P
|
|
11
24
|
|
|
@@ -13,6 +26,7 @@ from tpu_inference import utils
|
|
|
13
26
|
from tpu_inference.kernels.ragged_paged_attention.v3.kernel import \
|
|
14
27
|
ragged_paged_attention
|
|
15
28
|
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
|
|
29
|
+
from tpu_inference.layers.common.quantization import quantize_kv
|
|
16
30
|
from tpu_inference.layers.common.sharding import ShardingAxisName
|
|
17
31
|
from tpu_inference.layers.jax.base import create_param
|
|
18
32
|
from tpu_inference.layers.jax.rope_interface import apply_rope
|
|
@@ -149,9 +163,8 @@ class Attention(nnx.Module):
|
|
|
149
163
|
# q_scale = self._q_scale
|
|
150
164
|
k_scale = self._k_scale
|
|
151
165
|
v_scale = self._v_scale
|
|
152
|
-
k_SKH, v_SKH =
|
|
153
|
-
|
|
154
|
-
k_scale, v_scale)
|
|
166
|
+
k_SKH, v_SKH = quantize_kv(self.kv_cache_quantized_dtype, k_SKH,
|
|
167
|
+
v_SKH, k_scale, v_scale)
|
|
155
168
|
|
|
156
169
|
with jax.named_scope("attn_op"):
|
|
157
170
|
new_kv_cache, outputs_TNH = self.attention(
|
|
@@ -236,12 +249,12 @@ class Attention(nnx.Module):
|
|
|
236
249
|
)
|
|
237
250
|
|
|
238
251
|
output_TNH, kv_cache = jax.jit(
|
|
239
|
-
|
|
252
|
+
jax.shard_map(
|
|
240
253
|
_ragged_paged_attention,
|
|
241
254
|
mesh=mesh,
|
|
242
255
|
in_specs=in_specs,
|
|
243
256
|
out_specs=out_specs,
|
|
244
|
-
|
|
257
|
+
check_vma=False,
|
|
245
258
|
))(
|
|
246
259
|
q_TNH,
|
|
247
260
|
k_SKH,
|