tpu-inference 0.12.0.dev20251222__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +67 -0
- tests/core/test_dp_scheduler.py +724 -0
- tests/core/test_init.py +63 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +393 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +291 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +388 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +498 -0
- tests/kernels/quantized_matmul_kernel_test.py +159 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/layers/jax/test_qwix.py +969 -0
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +297 -0
- tests/layers/vllm/test_unquantized.py +621 -0
- tests/layers/vllm/utils.py +72 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +46 -0
- tests/lora/test_bgmv.py +57 -0
- tests/lora/test_layers.py +666 -0
- tests/lora/test_lora.py +147 -0
- tests/lora/test_lora_perf.py +67 -0
- tests/lora/utils.py +88 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +606 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +202 -0
- tests/runner/test_tpu_runner_dp.py +1033 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +215 -0
- tests/test_envs.py +280 -0
- tests/test_tpu_info.py +134 -0
- tests/test_utils.py +193 -0
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +67 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +49 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +814 -0
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +81 -0
- tpu_inference/distributed/tpu_connector.py +732 -0
- tpu_inference/distributed/utils.py +112 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +191 -0
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +399 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +272 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +1340 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +403 -0
- tpu_inference/layers/common/attention_metadata.py +48 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +23 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +600 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +268 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
- tpu_inference/layers/jax/base.py +165 -0
- tpu_inference/layers/jax/constants.py +101 -0
- tpu_inference/layers/jax/layers.py +315 -0
- tpu_inference/layers/jax/misc.py +30 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
- tpu_inference/layers/jax/moe/moe.py +249 -0
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +294 -0
- tpu_inference/layers/jax/rope_interface.py +228 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
- tpu_inference/layers/jax/sample/sampling.py +110 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
- tpu_inference/layers/jax/transformer_block.py +121 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +502 -0
- tpu_inference/layers/vllm/linear_common.py +221 -0
- tpu_inference/layers/vllm/quantization/__init__.py +55 -0
- tpu_inference/layers/vllm/quantization/awq.py +221 -0
- tpu_inference/layers/vllm/quantization/common.py +124 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
- tpu_inference/layers/vllm/sharding.py +244 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +98 -0
- tpu_inference/lora/torch_punica_tpu.py +310 -0
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +520 -0
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +978 -0
- tpu_inference/models/jax/gpt_oss.py +508 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
- tpu_inference/models/jax/llama3.py +436 -0
- tpu_inference/models/jax/llama4.py +643 -0
- tpu_inference/models/jax/llama_eagle3.py +350 -0
- tpu_inference/models/jax/llama_guard_4.py +375 -0
- tpu_inference/models/jax/qwen2.py +390 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
- tpu_inference/models/jax/qwen3.py +318 -0
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +110 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
- tpu_inference/models/jax/utils/weight_utils.py +621 -0
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
- tpu_inference/platforms/__init__.py +16 -0
- tpu_inference/platforms/tpu_platform.py +258 -0
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +890 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +166 -0
- tpu_inference/runner/kv_cache_manager.py +508 -0
- tpu_inference/runner/lora_utils.py +106 -0
- tpu_inference/runner/multimodal_manager.py +231 -0
- tpu_inference/runner/persistent_batch_manager.py +296 -0
- tpu_inference/runner/speculative_decoding_manager.py +262 -0
- tpu_inference/runner/structured_decoding_manager.py +101 -0
- tpu_inference/runner/tpu_runner.py +1768 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +430 -0
- tpu_inference/tpu_info.py +92 -0
- tpu_inference/utils.py +345 -0
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +468 -0
- tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
- tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
- tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
- tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tpu_inference/layers/common/binary_search.py
@@ -0,0 +1,295 @@
# Copyright 2024 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Binary search over float32 bits.

Includes fast algorithms for top-k masking and top-p masking on probability
distributions.
"""

from typing import Callable, Sequence

import jax
from jax import lax
from jax import numpy as jnp


def int32_bsearch(batch_shape: Sequence[int],
                  predicate: Callable[[jnp.ndarray], jnp.ndarray]):
  """Batched binary search over int32 values.

  For each element of the batch, search for the largest int32 (closest to
  positive infinity) for which the predicate is False. If the predicate is
  always True, returns the minimum int32 value.

  Args:
    batch_shape: Shape of the search that we're batching over.
    predicate: the query we're searching for. For every batch element, this is
      required to be a monotonic function from int32 to bool. In other words,
      the predicate must return False for all numbers <= some threshold and
      then return True for all numbers > that threshold. The threshold may be
      different for different elements of the batch.

  Returns:
    For each element of the batch, the largest int32 for which the predicate
    returns False. Shape: batch_shape.
  """
  current_bits = jnp.zeros(batch_shape, dtype=jnp.int32)

  # bit 31 is special, because it compares in the opposite order of all other
  # bits. we use uint32 due to numpy promotion/casting rules.
  midpoint = current_bits
  predicate_satisfied = predicate(midpoint)
  current_bits = current_bits | jnp.where(predicate_satisfied,
                                          jnp.uint32(1 << 31), jnp.uint32(0))
  del midpoint, predicate_satisfied

  def loop_body(i, current_bits):
    bit_index = 30 - i
    bit = jnp.int32(1 << bit_index)
    midpoint = current_bits | bit
    predicate_satisfied = predicate(midpoint)
    current_bits = current_bits | jnp.where(predicate_satisfied,
                                            jnp.int32(0), bit)
    return current_bits

  current_bits = lax.fori_loop(0, 31, loop_body, current_bits)
  return current_bits


def _monotonic_int32_to_float32_bit_pattern(x: int) -> int:
  """Converts an int32 to a float32 bit pattern with consistent ordering.

  This function is the unique function that is monotonic with respect to the
  floating point total order, see
  https://en.wikipedia.org/wiki/IEEE_754#Total-ordering_predicate. Note that
  this function returns an int32, not a float32. For the function that returns
  float32, see `_monotonic_int32_to_float32`.

  Args:
    x: int bit pattern.

  Returns:
    Bit pattern of a float32 number.
  """
  non_sign_bits = jnp.int32((1 << 31) - 1)
  # See
  # https://stackoverflow.com/questions/20097380/iee-754-total-order-in-standard-c11
  # for the relationship between int32 order and f32 total order, including
  # the "xor trick".

  # Flip the sort order for numbers where the sign bit is set. On int32,
  # the bit pattern with sign bit set and all other bits clear is the most
  # negative bit pattern (it's int32::MIN), whereas on float32 it's the least
  # negative bit pattern (it's -0.0). Flipping all the non-sign bits makes the
  # int32 sort order consistent with the float32 sort order.
  x = x ^ jnp.where(x < 0, non_sign_bits, jnp.int32(0))
  return x


def _monotonic_int32_to_float32(x: int) -> jax.Array:
  """Converts an int32 to a float32 with consistent ordering.

  This function is the unique function that is monotonic with respect to the
  floating point total order, see
  https://en.wikipedia.org/wiki/IEEE_754#Total-ordering_predicate.

  Args:
    x: int bit pattern.

  Returns:
    float32 number with consistent ordering.
  """
  x = _monotonic_int32_to_float32_bit_pattern(x)
  return lax.bitcast_convert_type(x, jnp.float32)


def float32_bsearch(batch_shape, predicate):
  """Binary search on finite float32 numbers.

  For each element of the batch, this function searches for the largest finite
  non-NaN float32 for which the predicate is False.

  Args:
    batch_shape: Shape of the search that we're batching over.
    predicate: the query we're searching for. This is required to be monotonic
      with respect to the floating point order, i.e. it must be False for all
      numbers <= a threshold, and then True for all numbers > the threshold.
      The threshold may be different for different elements of the batch.

  Returns:
    For each element of the batch, the largest float32 for which the predicate
    returns False. Shape: f32[batch_shape].
  """
  exponent_bits = jnp.int32((1 << 31) - (1 << (31 - 8)))

  def int32_predicate(x):
    x = _monotonic_int32_to_float32_bit_pattern(x)
    is_finite = (x & exponent_bits) != exponent_bits

    # Non-finite numbers (infinity and NaN) are at the very extremes of the
    # int32 range, i.e. they include int32::MAX and int32::MIN, plus the
    # numbers adjacent to them. For the nonfinite numbers touching int32::MIN,
    # we arrange for them to return False from the predicate, and for the
    # nonfinite numbers touching int32::MAX, we arrange for them to return
    # True from the predicate. x>=0 is an easy way to achieve that.
    predicate_on_nonfinite = x >= 0
    x_float32 = lax.bitcast_convert_type(x, jnp.float32)
    return jnp.where(is_finite, predicate(x_float32),
                     predicate_on_nonfinite)

  # We search over bit patterns, which requires bit shifting and ordering of
  # bit patterns. This is natively supported on int32 but not on float32.
  # Additionally, it's more common to reason about int32 bit arithmetic and
  # ordering than float32 bit arithmetic and ordering, so we do the core of
  # our search in int32. Additionally, this allows us to test the underlying
  # binary search on int32 values.
  #
  # The function _monotonic_int32_to_float32 encapsulates all of the knowledge
  # we need about float32 bit patterns.
  result = int32_bsearch(batch_shape, int32_predicate)
  return _monotonic_int32_to_float32(result)


def topk_mask(x: jnp.ndarray, k: int, replace_val: jnp.ndarray) -> jnp.ndarray:
  """Sets everything to replace_val, except the top k values per batch element.

  Sharding considerations: this function does 32 reductions over the
  vocab_size axis of the input array. To avoid excessive latency from these
  reductions, you should ensure that the vocab_size axis is unsharded on input
  to this function. Prefer to shard the batch axes instead.

  Scratchpad memory considerations: this function is most efficient if the
  entire input array can fit in a fast memory tier. To help ensure this, you
  may wish to split the batch axes into microbatches and process the
  microbatches in a sequential loop.

  Args:
    x: Values before masking. [batch..., vocab_size]
    k: Number of masked values to return. In presence of ties, more than k
      values might be returned.
    replace_val: For the masked values of x, what to overwrite them with.

  Returns:
    masked version of x. [batch..., vocab_size]
  """
  batch_shape = tuple(list(x.shape)[:-1])  # [batch...]

  x_for_loop = x
  reduce_axis = x.ndim - 1
  if x.ndim > 1:
    # We're going to be doing 32 reductions over 'reduce_axis'. Generally,
    # reductions over the last dimension are the most expensive, because they
    # involve reducing across vector lanes, which is often not efficient. So
    # we transpose the reduce_axis to be the second-last dimension, to avoid
    # this inefficiency.
    #
    # Normally the XLA compiler would automatically perform this optimization,
    # but it doesn't yet see through loops to do so. So we do it ourselves.
    x_for_loop = jnp.swapaxes(x_for_loop, -1, -2)
    reduce_axis = x.ndim - 2

  # x_for_loop: [batch..., vocab_size, last_batch]
  def predicate(threshold):
    # threshold: [batch...]

    # Since we've negated, we now want a predicate that is True for small
    # numbers and False for large numbers. The result of the bsearch is the
    # smallest float32 for which the predicate is False.
    threshold = -threshold

    threshold = lax.expand_dims(threshold, (reduce_axis, ))
    # threshold: [batch..., 1, last_batch]

    # count_gt: [batch...]
    count_gt = jnp.sum(x_for_loop > threshold, axis=reduce_axis)

    return count_gt >= k

  # cutoff: [batch...]
  cutoff = float32_bsearch(batch_shape, predicate)
  cutoff = -cutoff
  # cutoff: [batch..., 1]
  cutoff = lax.expand_dims(cutoff, (cutoff.ndim, ))
  return jnp.where(x >= cutoff, x, jnp.full_like(x, replace_val))


def topp_mask(logits: jnp.ndarray, p: float,
              replace_val: jnp.ndarray) -> jnp.ndarray:
  """Applies top-p masking to logits.

  Masks logits down to the smallest set of choices, such that the total
  probability mass is >= p. Values in this set are left as they are. All other
  values are set to `replace_val`.

  Sharding considerations: this function does 33 reductions over the
  vocab_size axis of the input array. To avoid excessive latency from these
  reductions, you should ensure that the vocab_size axis is unsharded on input
  to this function. Prefer to shard the batch axes instead.

  Scratchpad memory considerations: this function is most efficient if the
  entire input array can fit in a fast memory tier. To help ensure this, you
  may wish to split the batch axes into microbatches and process the
  microbatches in a sequential loop.

  Args:
    logits: Logits before masking. [batch..., vocab_size]
    p: Minimum probability mass requested.
    replace_val: For the masked values of logits, what to overwrite them with.

  Returns:
    masked version of logits. [batch..., vocab_size]
  """
  batch_shape = tuple(list(logits.shape)[:-1])  # [batch...]

  probs = jax.nn.softmax(logits, axis=-1)

  probs_for_reduction = probs
  reduce_axis = probs_for_reduction.ndim - 1
  if probs_for_reduction.ndim > 1:
    # We're going to be doing 33 reductions over 'reduce_axis'. Generally,
    # reductions over the last dimension are the most expensive, because they
    # involve reducing across vector lanes, which is often not efficient. So
    # we transpose the reduce_axis to be the second-last dimension, to avoid
    # this inefficiency.
    probs_for_reduction = jnp.swapaxes(probs_for_reduction, -1, -2)
    reduce_axis = probs_for_reduction.ndim - 2

  # As we increase the threshold, the probability mass decreases, and the
  # number selected decreases.
  #
  # We want the largest threshold with the probability mass >= p. Binary
  # search searches for when the predicate is False, so we negate the output
  # of the predicate, i.e. probability mass < p.

  # probs_for_reduction: [batch..., vocab_size, last_batch]
  def predicate(threshold):
    # threshold: [batch...]
    threshold = lax.expand_dims(threshold, (reduce_axis, ))
    # threshold: [batch..., 1, last_batch]

    # probability_mass: [batch...]
    probability_mass = jnp.sum(
        jnp.where(probs_for_reduction >= threshold, probs_for_reduction,
                  0.0),
        axis=reduce_axis,
    )

    return probability_mass < p

  # threshold: [batch...]
  threshold = float32_bsearch(batch_shape, predicate)
  # threshold: [batch..., 1]
  threshold = lax.expand_dims(threshold, (threshold.ndim, ))
  return jnp.where(probs >= threshold, logits,
                   jnp.full_like(logits, replace_val))
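
For orientation, here is a minimal usage sketch of the two maskers above, assuming they are importable from tpu_inference/layers/common/binary_search.py (the +295-line file in the listing); the logits values are purely illustrative:

from tpu_inference.layers.common.binary_search import topk_mask, topp_mask  # assumed import path
import jax.numpy as jnp

# Toy logits for a batch of 2 sequences over a vocab of 5 tokens.
logits = jnp.array([[1.0, 3.0, 2.0, -1.0, 0.5],
                    [0.0, 0.0, 5.0, 4.0, -2.0]])

# Keep the top-2 logits per row; all other positions become -inf.
top2 = topk_mask(logits, k=2, replace_val=jnp.float32(-jnp.inf))

# Keep the smallest set of tokens whose softmax mass reaches 0.9; mask the rest.
nucleus = topp_mask(logits, p=0.9, replace_val=jnp.float32(-jnp.inf))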
tpu_inference/layers/common/quant_methods.py
@@ -0,0 +1,23 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

UNQUANTIZED = "unquantized"
MXFP4 = "mxfp4"
AWQ = "awq"
COMPRESSED_TENSORS = "compressed-tensors"
FP8 = "fp8"


def get_tpu_quant_method(quant_method: str) -> str:
    return "tpu-" + quant_method
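
The helper above simply prefixes a quantization method name with "tpu-"; how the prefixed names get registered is outside this hunk. For example:

assert get_tpu_quant_method(AWQ) == "tpu-awq"
assert get_tpu_quant_method(COMPRESSED_TENSORS) == "tpu-compressed-tensors"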
tpu_inference/layers/common/quantization.py
@@ -0,0 +1,270 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
from typing import Tuple

import jax
import jax.numpy as jnp

MXFP4_BLOCK_SIZE = 32


def quantize_tensor_to_mxfp4_packed(
    tensor: jax.Array,
    axis: int | tuple = -1,
) -> Tuple[jax.Array, jax.Array]:
    """Quantize a tensor to mxfp4 and pack it into uint8."""

    # Perform regular block quantization.
    tensor_q, scale = quantize_tensor(
        jnp.float4_e2m1fn,
        tensor,
        axis,
        MXFP4_BLOCK_SIZE,
    )

    # The last two e2m1 elements will be packed into a single uint8 element.
    bitcast_shape = tensor_q.shape[:-1] + (-1, 2)
    tensor_q = tensor_q.reshape(bitcast_shape)
    tensor_q_packed = jax.lax.bitcast_convert_type(tensor_q, jnp.uint8)

    # Since TPU does not have native support for e8m0, we convert scale into
    # e8m0 manually and store it as uint8.
    e8m0_finfo = jnp.finfo(jnp.float8_e8m0fnu)
    _, scale_exp = jnp.frexp(scale)
    # Subtract one from the exponents since e8m0 has no decimal.
    scale_exp -= 1
    scale_exp = (scale_exp - e8m0_finfo.minexp).astype(jnp.uint8)

    return tensor_q_packed, scale_exp


def u8_unpack_e2m1(u8_packed_e2m1: jax.Array) -> jax.Array:
    """Unpack e2m1 tensor packed into u8."""
    assert u8_packed_e2m1.dtype == jnp.uint8
    e2m1 = jax.lax.bitcast_convert_type(u8_packed_e2m1, jnp.float4_e2m1fn)
    # bitcast creates one more dimension that splits 8 bits into two e2m1;
    # we flatten them with the last dim.
    return jnp.reshape(e2m1, e2m1.shape[:-2] + (-1, ))


def e8m0_to_fp32(u8: jax.Array) -> jax.Array:
    """Convert e8m0 (that was bitcasted to u8) into fp32."""
    assert u8.dtype == jnp.uint8

    e8_finfo = jnp.finfo(jnp.float8_e8m0fnu)
    exponents = u8.astype(jnp.int32) + e8_finfo.minexp
    ones = jnp.ones_like(u8, dtype=jnp.float32)
    return jnp.ldexp(ones, exponents)


def dequantize_tensor(
    tensor_q: jax.Array,
    scale: jax.Array,
    axis: int | None | tuple = -1,
    out_dtype: jnp.dtype = jnp.bfloat16,
) -> jax.Array:
    """Dequantize a quantized tensor.

    Args:
        tensor_q: Quantized tensor.
        scale: Quantization scale.
        axis: The axis the tensor was quantized along. None denotes per-tensor.
        out_dtype: Dtype of the output.

    Returns:
        Dequantized tensor_q.
    """
    if axis is None:
        # Perform per-tensor dequantization.
        axis = [i for i in range(tensor_q.ndim)]
    if isinstance(axis, int):
        axis = [axis]

    orig_shape = tensor_q.shape
    if tensor_q.ndim == scale.ndim:
        # Indicates the tensor was block quantized.
        blocked_shape = [[i] for i in orig_shape]
        for i in axis:
            num_blocks = scale.shape[i]
            if tensor_q.shape[i] % num_blocks:
                raise ValueError(
                    f"Unable to perform block dequantization. axis={i} of "
                    f"{tensor_q.shape=} is not divisible by {num_blocks=}", )
            block_size = tensor_q.shape[i] // num_blocks

            blocked_shape[i] = (num_blocks, block_size)

        # Convert all axis into positive values.
        axis = sorted([(i + tensor_q.ndim) % tensor_q.ndim for i in axis])
        # Shift axis by 1 since its original position is now occupied by
        # num_blocks dim. Also, if n axes before an axis was also quantized,
        # shift its position by n.
        axis = [1 + n + i for n, i in enumerate(axis)]

        # Flatten list of lists that contains (num_blocks, block).
        blocked_shape = list(itertools.chain(*blocked_shape))
        tensor_q = tensor_q.reshape(blocked_shape)

        scale = jnp.expand_dims(scale, axis)

    tensor = (tensor_q.astype(jnp.float32) * scale).astype(out_dtype)

    return tensor.reshape(orig_shape)


def dequantize_tensor_from_mxfp4_packed(
    tensor_q: jax.Array,
    scale: jax.Array,
    axis: int | tuple = -1,
    out_dtype: jnp.dtype = jnp.bfloat16,
) -> jax.Array:
    """Dequantize packed mxfp4 tensor.

    Args:
        tensor_q: fp4 tensor packed into uint8.
        scale: e8m0 scale packed into uint8.
        axis: The axis the tensor was quantized along.
        out_dtype: Dtype of the output.

    Returns:
        Dequantized tensor_q.
    """
    tensor_e2m1 = u8_unpack_e2m1(tensor_q)
    scale_fp32 = e8m0_to_fp32(scale)

    return dequantize_tensor(
        tensor_e2m1,
        scale_fp32,
        axis,
        out_dtype,
    )


def quantize_tensor(
    dtype: jnp.dtype,
    tensor: jax.Array,
    axis: int | tuple | None = -1,
    block_size: int | None = None,
    pad_tensor: bool = False,
) -> tuple[jax.Array, jax.Array]:
    """Quantize tensor.

    Args:
        dtype: dtype to perform quantization.
        tensor: Unquantized tensor.
        axis: Axis to perform quantization. None denotes per-tensor.
        block_size: Specify block quantization size.
        pad_tensor: Whether to pad the axis along block size.

    Returns:
        Tensor quantized to dtype.
    """
    if axis is None:
        # Perform per-tensor quantization.
        axis = [i for i in range(tensor.ndim)]
    if isinstance(axis, int):
        axis = [axis]

    orig_shape = tensor.shape
    mask = jnp.ones_like(tensor, jnp.int32)

    if block_size is not None:
        if isinstance(block_size, int):
            block_size = [block_size] * len(axis)

        blocked_shape = [[i] for i in orig_shape]
        pad_width = [[0, 0] for _ in range(tensor.ndim)]
        for i, block in zip(axis, block_size):
            num_blocks = (tensor.shape[i] + block - 1) // block
            padding_size = num_blocks * block - tensor.shape[i]
            if padding_size and not pad_tensor:
                raise ValueError(
                    f"Unable to perform block quantization. axis={i} of "
                    f"{tensor.shape=} is not divisible by {block=}")

            # Pad the tensor to align with block size.
            pad_width[i][1] = padding_size

            blocked_shape[i] = (num_blocks, block)

        # In order to avoid padded values affecting scale value, we pad it
        # using edge value of the tensor.
        tensor = jnp.pad(tensor, pad_width, "edge")
        mask = jnp.pad(mask, pad_width)

        orig_shape = tensor.shape
        # Convert all axis into positive values.
        axis = sorted([i % tensor.ndim for i in axis])
        # Shift axis by 1 since its original position is now occupied by
        # num_blocks dim. Also, if n axes before an axis was also quantized,
        # shift its position by n.
        axis = [1 + n + i for n, i in enumerate(axis)]

        # Flatten list of lists that contains (num_blocks, block).
        blocked_shape = list(itertools.chain(*blocked_shape))
        tensor = tensor.reshape(blocked_shape)

    if jnp.issubdtype(dtype, jnp.integer):
        dtype_info = jnp.iinfo(dtype)
    else:
        dtype_info = jnp.finfo(dtype)

    dtype_max = float(dtype_info.max)
    dtype_min = float(dtype_info.min)

    abs_max = jnp.max(jnp.abs(tensor), axis=axis, keepdims=True)
    scale = abs_max / dtype_max

    tensor_q = jnp.clip(tensor / scale, dtype_min, dtype_max)
    tensor_q = tensor_q.reshape(orig_shape)
    tensor_q = tensor_q.astype(dtype)

    # To avoid padded values affecting output of quantized matmul, we mask them
    # out with 0s.
    tensor_q = jnp.where(mask, tensor_q, 0)

    scale = jnp.squeeze(scale, axis).astype(jnp.float32)

    return tensor_q, scale


def static_per_tensor_quantize_tensor(
    dtype: jnp.dtype,
    tensor: jax.Array,
    scale: float,
) -> jax.Array:
    if jnp.issubdtype(dtype, jnp.integer):
        dtype_info = jnp.iinfo(dtype)
    else:
        dtype_info = jnp.finfo(dtype)

    dtype_max = float(dtype_info.max)
    dtype_min = float(dtype_info.min)

    return jnp.clip(tensor / scale, dtype_min, dtype_max).astype(dtype)


def quantize_kv(
    dtype: jnp.dtype,
    key: jax.Array,
    value: jax.Array,
    k_scale: float,
    v_scale: float,
) -> Tuple[jax.Array, jax.Array]:
    """Static quantize key and value tensors."""
    key = static_per_tensor_quantize_tensor(dtype, key, k_scale)
    value = static_per_tensor_quantize_tensor(dtype, value, v_scale)
    return key, value
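
A minimal round-trip sketch of the quantization helpers above, assuming they are importable from tpu_inference/layers/common/quantization.py (the +270-line file in the listing); shapes, dtypes, and block sizes here are illustrative only:

import jax
import jax.numpy as jnp

from tpu_inference.layers.common.quantization import (  # assumed import path
    dequantize_tensor, dequantize_tensor_from_mxfp4_packed, quantize_tensor,
    quantize_tensor_to_mxfp4_packed)

x = jax.random.normal(jax.random.PRNGKey(0), (128, 256), dtype=jnp.float32)

# Symmetric int8 block quantization along the last axis, 128 values per block.
x_q, scale = quantize_tensor(jnp.int8, x, axis=-1, block_size=128)
# x_q: int8[128, 256]; scale: float32[128, 2] (one scale per block).

# Dequantize back; scale.ndim == x_q.ndim, so the block path is taken.
x_deq = dequantize_tensor(x_q, scale, axis=-1, out_dtype=jnp.bfloat16)

# mxfp4 path: two e2m1 values packed per uint8, scales stored as e8m0 bytes.
w_q, w_scale = quantize_tensor_to_mxfp4_packed(x, axis=-1)
w_deq = dequantize_tensor_from_mxfp4_packed(w_q, w_scale, axis=-1)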