tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +317 -34
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +26 -6
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +25 -4
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +25 -12
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +32 -9
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +101 -494
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
- tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +112 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +18 -5
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +179 -51
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
- tpu_inference/models/jax/utils/weight_utils.py +234 -155
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +51 -72
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +180 -80
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +55 -33
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +16 -3
- tpu_inference/runner/tpu_runner.py +124 -61
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +84 -22
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +66 -52
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -186
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tpu_inference/layers/common/quantization.py (new file)
@@ -0,0 +1,282 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+from typing import Tuple
+
+import jax
+import jax.numpy as jnp
+
+MXFP4_BLOCK_SIZE = 32
+
+
+def quantize_tensor_to_mxfp4_packed(
+    tensor: jax.Array,
+    axis: int | tuple = -1,
+) -> Tuple[jax.Array, jax.Array]:
+    """Quantize a tensor to mxfp4 and pack it into uint8."""
+
+    # Perform regular block quantization.
+    tensor_q, scale = quantize_tensor(
+        jnp.float4_e2m1fn,
+        tensor,
+        axis,
+        MXFP4_BLOCK_SIZE,
+    )
+
+    # last two e2m1 elements will be packed into a single uint8 element.
+    bitcast_shape = tensor_q.shape[:-1] + (-1, 2)
+    tensor_q = tensor_q.reshape(bitcast_shape)
+    tensor_q_packed = jax.lax.bitcast_convert_type(tensor_q, jnp.uint8)
+
+    # Since TPU does not have native support for e8m0, we convert scale into
+    # e8m0 manually and store it as uint8.
+    e8m0_finfo = jnp.finfo(jnp.float8_e8m0fnu)
+    _, scale_exp = jnp.frexp(scale)
+    # Subtract exponents by one since e8m0 has no decimal.
+    scale_exp -= 1
+    scale_exp = (scale_exp - e8m0_finfo.minexp).astype(jnp.uint8)
+
+    return tensor_q_packed, scale_exp
+
+
+def u8_unpack_e2m1(u8_packed_e2m1: jax.Array) -> jax.Array:
+    """Unpack e2m1 tensor that was packed into u8."""
+    assert u8_packed_e2m1.dtype == jnp.uint8
+    e2m1 = jax.lax.bitcast_convert_type(u8_packed_e2m1, jnp.float4_e2m1fn)
+    # bitcast creates one more dimension that splits 8 bits into two e2m1.
+    # we flatten them with the last dim.
+    return jnp.reshape(e2m1, e2m1.shape[:-2] + (-1, ))
+
+
+def e8m0_to_fp32(u8: jax.Array) -> jax.Array:
+    """Convert e8m0 (that was bitcasted to u8) into fp32."""
+    assert u8.dtype == jnp.uint8
+
+    e8_finfo = jnp.finfo(jnp.float8_e8m0fnu)
+    exponents = u8.astype(jnp.int32) + e8_finfo.minexp
+    ones = jnp.ones_like(u8, dtype=jnp.float32)
+    return jnp.ldexp(ones, exponents)
+
+
+def awq_u32_unpack_u4(awq_u32_packed: jax.Array) -> jax.Array:
+    """Unpack u4 tensor that was packed into u32 in awq ordering."""
+
+    awq_u4 = jax.lax.bitcast_convert_type(awq_u32_packed, jnp.uint4)
+
+    # AWQ packs 8 uint4 into 32-bits in this order: (0, 2, 4, 6, 1, 3, 5, 7).
+    # Following list maps the order used by AWQ into an ascending order.
+    reverse_awq_order = (0, 4, 1, 5, 2, 6, 3, 7)
+    u4 = awq_u4[..., reverse_awq_order]
+    return jnp.reshape(u4, u4.shape[:-2] + (-1, ))
+
+
+def dequantize_tensor(
+    tensor_q: jax.Array,
+    scale: jax.Array,
+    axis: int | None | tuple = -1,
+    out_dtype: jnp.dtype = jnp.bfloat16,
+) -> jax.Array:
+    """Dequantize a quantized tensor
+
+    Args:
+        tensor_q: Quantized tensor.
+        scale: Quantization scale.
+        axis: The axis tensor was quantized. None denotes per-tensor.
+        out_dtype: Dtype of the output.
+
+    Returns:
+        Dequantized tensor_q.
+    """
+    if axis is None:
+        # Perform per-tensor quantization.
+        axis = [i for i in range(tensor_q.ndim)]
+    if isinstance(axis, int):
+        axis = [axis]
+
+    orig_shape = tensor_q.shape
+    if tensor_q.ndim == scale.ndim:
+        # Indicates the tensor was block quantized.
+        blocked_shape = [[i] for i in orig_shape]
+        for i in axis:
+            num_blocks = scale.shape[i]
+            if tensor_q.shape[i] % num_blocks:
+                raise ValueError(
+                    f"Unable to perform block dequantization. axis={i} of "
+                    f"{tensor_q.shape=} is not divisible by {num_blocks=}", )
+            block_size = tensor_q.shape[i] // num_blocks
+
+            blocked_shape[i] = (num_blocks, block_size)
+
+        # Convert all axis into positive values.
+        axis = sorted([(i + tensor_q.ndim) % tensor_q.ndim for i in axis])
+        # Shift axis by 1 since its original position is now occupied by
+        # num_blocks dim. Also, if n axes before an axis was also quantized,
+        # shift its position by n.
+        axis = [1 + n + i for n, i in enumerate(axis)]
+
+        # Flatten list of lists that contains (num_blocks, block).
+        blocked_shape = list(itertools.chain(*blocked_shape))
+        tensor_q = tensor_q.reshape(blocked_shape)
+
+        scale = jnp.expand_dims(scale, axis)
+
+    tensor = (tensor_q.astype(jnp.float32) * scale).astype(out_dtype)
+
+    return tensor.reshape(orig_shape)
+
+
+def dequantize_tensor_from_mxfp4_packed(
+    tensor_q: jax.Array,
+    scale: jax.Array,
+    axis: int | tuple = -1,
+    out_dtype: jnp.dtype = jnp.bfloat16,
+) -> jax.Array:
+    """Dequantize packed mxfp4 tensor.
+
+    Args:
+        tensor_q: fp4 tensor packed into uint8.
+        scale: e8m0 scale packed into uint8.
+        axis: The axis tensor was quantized.
+        out_dtype: Dtype of the output.
+
+    Returns:
+        Dequantized tensor_q.
+    """
+    tensor_e2m1 = u8_unpack_e2m1(tensor_q)
+    scale_fp32 = e8m0_to_fp32(scale)
+
+    return dequantize_tensor(
+        tensor_e2m1,
+        scale_fp32,
+        axis,
+        out_dtype,
+    )
+
+
+def quantize_tensor(
+    dtype: jnp.dtype,
+    tensor: jax.Array,
+    axis: int | tuple | None = -1,
+    block_size: int | None = None,
+    pad_tensor: bool = False,
+) -> tuple[jax.Array, jax.Array]:
+    """Quantize tensor.
+
+    Args:
+        dtype: dtype to perform quantization.
+        tensor: Unquantized tensor
+        axis: Axis to perform quantization. None denotes per-tensor.
+        block_size: Specify block quantization size.
+        pad_tensor: Whether to pad the axis along block size.
+
+    Returns:
+        Tensor quantized to dtype.
+    """
+    if axis is None:
+        # Perform per-tensor quantization.
+        axis = [i for i in range(tensor.ndim)]
+    if isinstance(axis, int):
+        axis = [axis]
+
+    orig_shape = tensor.shape
+    mask = jnp.ones_like(tensor, jnp.int32)
+
+    if block_size is not None:
+        if isinstance(block_size, int):
+            block_size = [block_size] * len(axis)
+
+        blocked_shape = [[i] for i in orig_shape]
+        pad_width = [[0, 0] for _ in range(tensor.ndim)]
+        for i, block in zip(axis, block_size):
+            num_blocks = (tensor.shape[i] + block - 1) // block
+            padding_size = num_blocks * block - tensor.shape[i]
+            if padding_size and not pad_tensor:
+                raise ValueError(
+                    f"Unable to perform block quantization. axis={i} of "
+                    f"{tensor.shape=} is not divisible by {block=}")
+
+            # Pad the tensor to align with block size.
+            pad_width[i][1] = padding_size
+
+            blocked_shape[i] = (num_blocks, block)
+
+        # In order to avoid padded values affecting scale value, we pad it
+        # using edge value of the tensor.
+        tensor = jnp.pad(tensor, pad_width, "edge")
+        mask = jnp.pad(mask, pad_width)
+
+        orig_shape = tensor.shape
+        # Convert all axis into positive values.
+        axis = sorted([i % tensor.ndim for i in axis])
+        # Shift axis by 1 since its original position is now occupied by
+        # num_blocks dim. Also, if n axes before an axis was also quantized,
+        # shift its position by n.
+        axis = [1 + n + i for n, i in enumerate(axis)]
+
+        # Flatten list of lists that contains (num_blocks, block).
+        blocked_shape = list(itertools.chain(*blocked_shape))
+        tensor = tensor.reshape(blocked_shape)
+
+    if jnp.issubdtype(dtype, jnp.integer):
+        dtype_info = jnp.iinfo(dtype)
+    else:
+        dtype_info = jnp.finfo(dtype)
+
+    dtype_max = float(dtype_info.max)
+    dtype_min = float(dtype_info.min)
+
+    abs_max = jnp.max(jnp.abs(tensor), axis=axis, keepdims=True)
+    scale = abs_max / dtype_max
+
+    tensor_q = jnp.clip(tensor / scale, dtype_min, dtype_max)
+    tensor_q = tensor_q.reshape(orig_shape)
+    tensor_q = tensor_q.astype(dtype)
+
+    # To avoid padded values affecting output of quantized matmul, we mask them
+    # out with 0s.
+    tensor_q = jnp.where(mask, tensor_q, 0)
+
+    scale = jnp.squeeze(scale, axis).astype(jnp.float32)
+
+    return tensor_q, scale
+
+
+def static_per_tensor_quantize_tensor(
+    dtype: jnp.dtype,
+    tensor: jax.Array,
+    scale: float,
+) -> jax.Array:
+    if jnp.issubdtype(dtype, jnp.integer):
+        dtype_info = jnp.iinfo(dtype)
+    else:
+        dtype_info = jnp.finfo(dtype)
+
+    dtype_max = float(dtype_info.max)
+    dtype_min = float(dtype_info.min)
+
+    return jnp.clip(tensor / scale, dtype_min, dtype_max).astype(dtype)
+
+
+def quantize_kv(
+    dtype: jnp.dtype,
+    key: jax.Array,
+    value: jax.Array,
+    k_scale: float,
+    v_scale: float,
+) -> Tuple[jax.Array, jax.Array]:
+    """Static quantize key and value tensors."""
+    key = static_per_tensor_quantize_tensor(dtype, key, k_scale)
+    value = static_per_tensor_quantize_tensor(dtype, value, v_scale)
+    return key, value
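For orientation, here is a minimal round-trip sketch of the helpers added above (not part of the diff). It assumes the module is importable as tpu_inference.layers.common.quantization, per its path in this wheel, and that the installed jax/ml_dtypes build exposes float4_e2m1fn and float8_e8m0fnu:

import jax
import jax.numpy as jnp

from tpu_inference.layers.common.quantization import (
    dequantize_tensor, dequantize_tensor_from_mxfp4_packed, quantize_tensor,
    quantize_tensor_to_mxfp4_packed)

w = jax.random.normal(jax.random.key(0), (256, 512), dtype=jnp.float32)

# Per-channel int8 quantization along the last axis: one fp32 scale per row.
w_q, scale = quantize_tensor(jnp.int8, w, axis=-1)
print(w_q.dtype, scale.shape)  # int8 (256,)

# Block quantization keeps one scale per 128-wide block of the last axis.
w_qb, scale_b = quantize_tensor(jnp.int8, w, axis=-1, block_size=128)
print(scale_b.shape)  # (256, 4)
w_hat = dequantize_tensor(w_qb, scale_b, axis=-1)  # back to bfloat16

# MXFP4: e2m1 values packed two per uint8, plus a uint8 e8m0 scale for each
# 32-element block (MXFP4_BLOCK_SIZE).
w_p, s_e8m0 = quantize_tensor_to_mxfp4_packed(w, axis=-1)
print(w_p.shape, s_e8m0.shape)  # (256, 256) (256, 16)
w_mx = dequantize_tensor_from_mxfp4_packed(w_p, s_e8m0, axis=-1)
print(jnp.mean(jnp.abs(w_mx.astype(jnp.float32) - w)))  # small, nonzero error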
tpu_inference/layers/common/sharding.py
@@ -1,6 +1,19 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import json
 import math
-import os
 from dataclasses import asdict, dataclass
 from typing import TYPE_CHECKING, List, Optional
 
@@ -8,7 +21,7 @@ import jax.numpy as jnp
 import numpy as np
 from jax.sharding import Mesh
 
-from tpu_inference import utils
+from tpu_inference import envs, utils
 
 if TYPE_CHECKING:
     from vllm.v1.configs.vllm_config import VllmConfig
@@ -27,7 +40,7 @@ class ShardingAxisNameBase:
     MLP_TENSOR = ('attn_dp', 'model', 'expert')
     MOE_TENSOR = ('attn_dp', 'model')
     EXPERT = ('attn_dp', 'expert', 'model')
-    VOCAB = ('expert', 'model')
+    VOCAB = ('expert', 'attn_dp', 'model')
 
 
 class ShardingAxisName2D:
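Illustrative sketch (not from the diff) of what the widened VOCAB axis tuple means: the vocab dimension of an embedding table can now be partitioned jointly across the 'expert', 'attn_dp' and 'model' mesh axes. The 1 x 1 x N mesh below is made up for the sketch:

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Made-up 1 x 1 x N device mesh just to name the three logical axes.
devices = np.array(jax.devices()).reshape(1, 1, -1)
mesh = Mesh(devices, ("expert", "attn_dp", "model"))

# Embedding table: vocab dim split over all three axes, hidden dim replicated.
vocab_sharding = NamedSharding(mesh, P(("expert", "attn_dp", "model"), None))
print(vocab_sharding.spec)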
@@ -48,7 +61,7 @@ class ShardingAxisName2D:
 
 
 try:
-    _use_base_sharding =
+    _use_base_sharding = envs.NEW_MODEL_DESIGN
     if _use_base_sharding:
         ShardingAxisName = ShardingAxisNameBase
     else:
@@ -120,10 +133,19 @@ class ShardingConfigManager:
                 False)
         if enable_dp_attention:
             # Replicate attention layer when num_kv_heads < TP
-            num_kv_heads = vllm_config.model_config.get_total_num_kv_heads(
+            num_kv_heads = 1 if vllm_config.model_config.use_mla else vllm_config.model_config.get_total_num_kv_heads(
+            )
+            cache_dtype = vllm_config.cache_config.cache_dtype
+            if cache_dtype == 'auto':
+                cache_dtype = vllm_config.model_config.dtype
             kv_dtype = utils.get_jax_dtype_from_str_dtype(
-
+                cache_dtype) or jnp.bfloat16
             packing = 4 // jnp.dtype(kv_dtype).itemsize
+
+            # The default head dim is 128 but 64 is also supported as a special case.
+            if vllm_config.model_config.get_head_size() == 64:
+                packing *= 2
+
             # When num_kv_heads * 2 / packing < TP, tensor parallelism would
             # duplicate KV heads across devices, wasting kv cache memory.
             # Use attention DP instead to reduce per-device num_kv_heads and
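The packing arithmetic above can be checked standalone; kv_packing is a made-up name for this sketch, and its body simply mirrors the hunk:

import jax.numpy as jnp

def kv_packing(kv_dtype, head_size: int) -> int:
    packing = 4 // jnp.dtype(kv_dtype).itemsize  # elements per 32-bit word
    if head_size == 64:  # the special-cased smaller head dim doubles packing
        packing *= 2
    return packing

print(kv_packing(jnp.bfloat16, 128))       # 2
print(kv_packing(jnp.float8_e4m3fn, 128))  # 4
print(kv_packing(jnp.float8_e4m3fn, 64))   # 8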
@@ -166,10 +188,11 @@ class ShardingConfigManager:
                 f"LoRA is not supported with data parallelism "
                 f"(DP size: {total_dp_size}). Please disable LoRA or "
                 f"set data parallelism to 1.")
-
+        if sharding_strategy.attention_data_parallelism > 1:
+            if not envs.NEW_MODEL_DESIGN:
                 raise ValueError(
-                    "Must run DP with NEW_MODEL_DESIGN enabled. Please set
-                    "NEW_MODEL_DESIGN=True
+                    "Must run Attention DP with NEW_MODEL_DESIGN enabled. Please set "
+                    "NEW_MODEL_DESIGN=True")
 
     @property
     def total_dp_size(self) -> int:
tpu_inference/layers/common/utils.py (new file)
@@ -0,0 +1,94 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import jax
+import jax.numpy as jnp
+
+
+def reorder_concatenated_tensor_for_sharding(concatenated_tensor: jax.Array,
+                                             split_sizes: list[int],
+                                             n_shards: int, dim: int):
+    """
+    Reorder a replicated concatenated tensor such that when sharded on multiple chips, each shard is a concatenation of the shards of the individual tensors.
+    For example, let the concatenated_tensor be:
+        AAAAAAAAAAAABBBBBBBBCCCC
+          12 As      8 Bs  4 Cs
+    and let the split_sizes = [12, 8, 4] and n_shards = 4.
+    The output is:
+        AAABBCAAABBCAAABBCAAABBC
+    In other words, it reorders the input tensor into 4 segements, with each segment corresponding to a shard and being AAABBC.
+    Args:
+        concatenated_tensor: the tensor, concatenated on the dimension specified by `dim`.
+        split_sizes: each individual tensor's size on the dimension specified by `dim`.
+        n_shards: num of shards.
+        dim: the dimension on which the concatenated_tensor is concatenated.
+    """
+    # Split the concatenated tensor into individual tensors.
+    if dim < 0:
+        dim += concatenated_tensor.ndim
+    split_tensors = []
+    start_offset = 0
+    old_shape = concatenated_tensor.shape
+    # New shape ensures each split_tensor[i] maps to a tensor in ith shards
+    new_shape = old_shape[:dim] + (n_shards, -1) + old_shape[dim + 1:]
+    for split_size in split_sizes:
+        split_tensor = jax.lax.slice_in_dim(concatenated_tensor,
+                                            start_offset,
+                                            start_offset + split_size,
+                                            axis=dim)
+        split_tensors.append(split_tensor.reshape(new_shape))
+        start_offset += split_size
+    # While maintaining 0th dim as a shard dim, we concatenate along 1th dim to
+    # to create concatenated tnensor where 0th dim maps to shard dim.
+    reordered_tensor = jnp.concatenate(split_tensors, axis=dim + 1)
+    return reordered_tensor.reshape(old_shape)
+
+
+def slice_sharded_tensor_for_concatenation(sharded_tensor: jax.Array,
+                                           split_sizes: list[int],
+                                           n_shards: int):
+    """
+    Slice the input tensor which is sharded on multiple chips (on the last dim) into individual tensors with the same sharding.
+    For example, let the sharded_tensor be:
+        AAABBC | AAABBC | AAABBC | AAABBC
+        Shard0   Shard1   Shard2   Shard3
+    and let the split_sizes = [12, 8, 4] and n_shards = 4.
+    The output is a list of 3 tensors:
+        AAA | AAA | AAA | AAA
+        BB  | BB  | BB  | BB
+        C   | C   | C   | C
+      Shard0 Shard1 Shard2 Shard3
+    In other words, each individual tensor is a slice of the input tensor with the same sharding.
+    Args:
+        sharded_tensor: the input tensor, sharded on the last dim.
+        split_sizes: each individual tensor's size on the last dim.
+        n_shards: num of shards.
+    """
+    new_shape = sharded_tensor.shape[:-1] + (n_shards, -1)
+    # New shape ensures each sharded_tensor[:, i] maps to a tensor in ith shards
+    sharded_tensor = sharded_tensor.reshape(new_shape)
+
+    split_tensors = []
+    start_offset = 0
+    for split_size in split_sizes:
+        assert split_size % n_shards == 0
+        sz = split_size // n_shards  # size of this split tensor per shard
+        end_offset = start_offset + sz
+        # Because we are slicing over last dim, sharding dim remains intact.
+        # Therefore, splitting happens locally.
+        split_tensor = sharded_tensor[..., start_offset:end_offset]
+        split_tensors.append(split_tensor.reshape(new_shape[:-2] + (-1, )))
+        start_offset = end_offset
+
+    return split_tensors
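A quick check (not part of the diff) of the AAABBC example from the docstrings above, assuming the helpers are importable from tpu_inference.layers.common.utils:

import jax.numpy as jnp

from tpu_inference.layers.common.utils import (
    reorder_concatenated_tensor_for_sharding,
    slice_sharded_tensor_for_concatenation)

# 12 "A"s, 8 "B"s and 4 "C"s concatenated on one dim, to be sharded 4 ways.
concat = jnp.concatenate(
    [jnp.full((12,), 1), jnp.full((8,), 2), jnp.full((4,), 3)])

reordered = reorder_concatenated_tensor_for_sharding(
    concat, split_sizes=[12, 8, 4], n_shards=4, dim=0)
print(reordered)  # [1 1 1 2 2 3] repeated once per shard

# Slicing the reordered (sharded) layout back yields per-tensor slices whose
# shards line up with the original A/B/C pieces.
a, b, c = slice_sharded_tensor_for_concatenation(
    reordered, split_sizes=[12, 8, 4], n_shards=4)
print(a.shape, b.shape, c.shape)  # (12,) (8,) (4,)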
New file (Apache 2.0 license header only)
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
New file (Apache 2.0 license header only)
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tpu_inference/layers/jax/attention/attention.py
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from dataclasses import InitVar, dataclass
 from typing import Any, Tuple
 
@@ -5,7 +19,6 @@ import jax
 import jax.numpy as jnp
 from flax import nnx
 from flax.typing import Sharding
-from jax.experimental import shard_map
 from jax.sharding import Mesh
 from jax.sharding import PartitionSpec as P
 
@@ -13,6 +26,7 @@ from tpu_inference import utils
 from tpu_inference.kernels.ragged_paged_attention.v3.kernel import \
     ragged_paged_attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.quantization import quantize_kv
 from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.jax.base import create_param
 from tpu_inference.layers.jax.rope_interface import apply_rope
@@ -149,9 +163,8 @@ class Attention(nnx.Module):
             # q_scale = self._q_scale
             k_scale = self._k_scale
             v_scale = self._v_scale
-            k_SKH, v_SKH =
-
-                k_scale, v_scale)
+            k_SKH, v_SKH = quantize_kv(self.kv_cache_quantized_dtype, k_SKH,
+                                       v_SKH, k_scale, v_scale)
 
         with jax.named_scope("attn_op"):
             new_kv_cache, outputs_TNH = self.attention(
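Minimal sketch of calling the quantize_kv helper that now backs this code path; the shapes and static scales below are made up, not taken from the model config:

import jax
import jax.numpy as jnp

from tpu_inference.layers.common.quantization import quantize_kv

# Made-up token/head/dim shapes, just to show the call shape.
k = jax.random.normal(jax.random.key(0), (16, 8, 128), dtype=jnp.bfloat16)
v = jax.random.normal(jax.random.key(1), (16, 8, 128), dtype=jnp.bfloat16)

# Per-tensor static scales clamp K/V into the fp8 representable range.
k_q, v_q = quantize_kv(jnp.float8_e4m3fn, k, v, k_scale=0.05, v_scale=0.05)
print(k_q.dtype, v_q.dtype)  # float8_e4m3fn float8_e4m3fn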
@@ -236,12 +249,12 @@
         )
 
         output_TNH, kv_cache = jax.jit(
-
+            jax.shard_map(
                 _ragged_paged_attention,
                 mesh=mesh,
                 in_specs=in_specs,
                 out_specs=out_specs,
-
+                check_vma=False,
             ))(
                 q_TNH,
                 k_SKH,