tpu-inference 0.12.0.dev20251222__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +67 -0
- tests/core/test_dp_scheduler.py +724 -0
- tests/core/test_init.py +63 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +393 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +291 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +388 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +498 -0
- tests/kernels/quantized_matmul_kernel_test.py +159 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/layers/jax/test_qwix.py +969 -0
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +297 -0
- tests/layers/vllm/test_unquantized.py +621 -0
- tests/layers/vllm/utils.py +72 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +46 -0
- tests/lora/test_bgmv.py +57 -0
- tests/lora/test_layers.py +666 -0
- tests/lora/test_lora.py +147 -0
- tests/lora/test_lora_perf.py +67 -0
- tests/lora/utils.py +88 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +606 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +202 -0
- tests/runner/test_tpu_runner_dp.py +1033 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +215 -0
- tests/test_envs.py +280 -0
- tests/test_tpu_info.py +134 -0
- tests/test_utils.py +193 -0
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +67 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +49 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +814 -0
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +81 -0
- tpu_inference/distributed/tpu_connector.py +732 -0
- tpu_inference/distributed/utils.py +112 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +191 -0
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +399 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +272 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +1340 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +403 -0
- tpu_inference/layers/common/attention_metadata.py +48 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +23 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +600 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +268 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
- tpu_inference/layers/jax/base.py +165 -0
- tpu_inference/layers/jax/constants.py +101 -0
- tpu_inference/layers/jax/layers.py +315 -0
- tpu_inference/layers/jax/misc.py +30 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
- tpu_inference/layers/jax/moe/moe.py +249 -0
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +294 -0
- tpu_inference/layers/jax/rope_interface.py +228 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
- tpu_inference/layers/jax/sample/sampling.py +110 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
- tpu_inference/layers/jax/transformer_block.py +121 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +502 -0
- tpu_inference/layers/vllm/linear_common.py +221 -0
- tpu_inference/layers/vllm/quantization/__init__.py +55 -0
- tpu_inference/layers/vllm/quantization/awq.py +221 -0
- tpu_inference/layers/vllm/quantization/common.py +124 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
- tpu_inference/layers/vllm/sharding.py +244 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +98 -0
- tpu_inference/lora/torch_punica_tpu.py +310 -0
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +520 -0
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +978 -0
- tpu_inference/models/jax/gpt_oss.py +508 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
- tpu_inference/models/jax/llama3.py +436 -0
- tpu_inference/models/jax/llama4.py +643 -0
- tpu_inference/models/jax/llama_eagle3.py +350 -0
- tpu_inference/models/jax/llama_guard_4.py +375 -0
- tpu_inference/models/jax/qwen2.py +390 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
- tpu_inference/models/jax/qwen3.py +318 -0
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +110 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
- tpu_inference/models/jax/utils/weight_utils.py +621 -0
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
- tpu_inference/platforms/__init__.py +16 -0
- tpu_inference/platforms/tpu_platform.py +258 -0
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +890 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +166 -0
- tpu_inference/runner/kv_cache_manager.py +508 -0
- tpu_inference/runner/lora_utils.py +106 -0
- tpu_inference/runner/multimodal_manager.py +231 -0
- tpu_inference/runner/persistent_batch_manager.py +296 -0
- tpu_inference/runner/speculative_decoding_manager.py +262 -0
- tpu_inference/runner/structured_decoding_manager.py +101 -0
- tpu_inference/runner/tpu_runner.py +1768 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +430 -0
- tpu_inference/tpu_info.py +92 -0
- tpu_inference/utils.py +345 -0
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +468 -0
- tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
- tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
- tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
- tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tpu_inference/kernels/quantized_matmul/kernel.py
@@ -0,0 +1,456 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Quantized matmul kernel."""
+
+import functools
+
+import jax
+import jax.numpy as jnp
+from jax._src import dtypes
+from jax.experimental import pallas as pl
+from jax.experimental.pallas import tpu as pltpu
+
+from tpu_inference.kernels.quantized_matmul import util
+from tpu_inference.kernels.quantized_matmul.tuned_block_sizes import (
+    TunedValue, get_device_vmem_limit, get_tuned_block_sizes)
+from tpu_inference.kernels.quantized_matmul.util import (get_kernel_name,
+                                                         next_multiple,
+                                                         unfold_args)
+
+quantize_tensor = util.quantize_tensor
+
+
+def xla_quantized_matmul(
+    x: jax.Array,
+    w_q: jax.Array,
+    w_scale: jax.Array,
+    quantize_activation=True,
+) -> jax.Array:
+    """
+    Reference (pure JAX) implementation of the quantized matmul kernel below.
+
+    Args:
+      x: Activation.
+      w_q: Weight quantized array. [n_output_features, n_input_features]
+      w_scale: Weight quantization scale. [n_output_features]
+      quantize_activation: Whether to quantize the activation before the matmul.
+
+    Returns:
+      Output of the quantized matmul.
+    """
+    if quantize_activation:
+        acc_dtype = jnp.float32
+        if quantize_activation and jnp.issubdtype(w_q.dtype, jnp.integer):
+            acc_dtype = jnp.int32
+
+        x_q, x_scale = quantize_tensor(x, w_q.dtype)
+        out = jax.lax.dot_general(
+            x_q,
+            w_q,
+            dimension_numbers=(((1, ), (1, )), ((), ())),
+            preferred_element_type=acc_dtype,
+        ).astype(jnp.float32)
+        out *= x_scale
+    else:
+        out = jax.lax.dot_general(
+            x,
+            w_q,
+            dimension_numbers=(((1, ), (1, )), ((), ())),
+            preferred_element_type=jnp.float32,
+        )
+    out *= jnp.expand_dims(w_scale, 0)
+    return out.astype(x.dtype)
+
+
+def quantize_array(
+    x: jax.Array,  # [bs_block_size, in_block_size]
+    x_abs_max: jax.Array,  # [1, bs_block_size]
+    quant_dtype: jnp.dtype,
+):
+    is_float = jnp.issubdtype(quant_dtype, jnp.floating)
+    dtype_info = jnp.finfo(quant_dtype) if is_float else jnp.iinfo(quant_dtype)
+    dtype_max = float(dtype_info.max)
+
+    # TODO(kyuyeunk): Investigate performance gain from non xlu transpose.
+    scale = jnp.transpose(x_abs_max / dtype_max)
+    return (x / scale).astype(quant_dtype), scale.astype(jnp.float32)
+
+
+def get_vmem_limit(
+    n_batch: int,
+    n_out: int,
+    n_in: int,
+    batch_block_size: int,
+    out_block_size: int,
+    in_block_size: int,
+    x_dtype: jnp.dtype,
+    x_q_dtype: jnp.dtype,
+    w_q_dtype: jnp.dtype,
+    scale_dtype: jnp.dtype,
+    out_dtype: jnp.dtype,
+    acc_dtype: jnp.dtype,
+    save_acc: bool,
+    save_x_q: bool,
+    upper_limit_bytes: int,
+):
+    """Calculate the VMEM limit for the kernel."""
+
+    # Calculate in/out VMEM size.
+    x_size = (batch_block_size *
+              in_block_size * (dtypes.bit_width(x_dtype) if hasattr(
+                  dtypes, "bit_width") else dtypes.itemsize_bits(x_dtype)))
+    x_abs_max_size = (
+        batch_block_size * (dtypes.bit_width(scale_dtype) if hasattr(
+            dtypes, "bit_width") else dtypes.itemsize_bits(scale_dtype)))
+    w_q_size = (out_block_size *
+                in_block_size * (dtypes.bit_width(w_q_dtype) if hasattr(
+                    dtypes, "bit_width") else dtypes.itemsize_bits(w_q_dtype)))
+    w_scale_size = (out_block_size * (dtypes.bit_width(scale_dtype) if hasattr(
+        dtypes, "bit_width") else dtypes.itemsize_bits(scale_dtype)))
+    out_size = (batch_block_size *
+                out_block_size * (dtypes.bit_width(out_dtype) if hasattr(
+                    dtypes, "bit_width") else dtypes.itemsize_bits(out_dtype)))
+
+    vmem_in_out = x_size + x_abs_max_size + w_q_size + w_scale_size + out_size
+    vmem_in_out *= 2  # Account for compute and vreg spills.
+
+    # Account for double buffering.
+    # Double buffering is used only if there are multiple blocks per in/out.
+    vmem_in_out += x_size if (n_batch > 1 or n_in > 1) else 0
+    vmem_in_out += x_abs_max_size if (n_batch > 1) else 0
+    vmem_in_out += w_q_size if (n_out > 1 or n_in > 1) else 0
+    vmem_in_out += w_scale_size if (n_out > 1) else 0
+    vmem_in_out += out_size if (n_batch > 1 or n_out > 1) else 0
+
+    # Calculate scratch VMEM size.
+    acc_size = (batch_block_size *
+                out_block_size * (dtypes.bit_width(acc_dtype) if hasattr(
+                    dtypes, "bit_width") else dtypes.itemsize_bits(acc_dtype)))
+    x_q_size = (batch_block_size *
+                in_block_size * (dtypes.bit_width(x_q_dtype) if hasattr(
+                    dtypes, "bit_width") else dtypes.itemsize_bits(x_q_dtype)))
+    x_scale_size = (
+        batch_block_size * (dtypes.bit_width(scale_dtype) if hasattr(
+            dtypes, "bit_width") else dtypes.itemsize_bits(scale_dtype)))
+
+    vmem_scratch = acc_size if save_acc else 0
+    vmem_scratch += x_q_size + x_scale_size if save_x_q else 0
+    vmem_scratch *= 2  # Account for compute and vreg spills.
+
+    # Add in/out and scratch VMEM size.
+    vmem_used = vmem_in_out + vmem_scratch
+    vmem_used_bytes = vmem_used // 8  # Convert bits to bytes.
+    # Specify upper limit. Defaults to 96MB.
+    vmem_limit_bytes = min(vmem_used_bytes, upper_limit_bytes)
+
+    return vmem_limit_bytes
+
+
+def validate_inputs(
+    x: jax.Array,
+    w_q: jax.Array,
+    w_scale: jax.Array,
+    x_abs_max: jax.Array,
+    x_q_dtype: jnp.dtype,
+    batch_block_size: int,
+    out_block_size: int,
+    in_block_size: int,
+):
+    """Verify inputs before invoking the kernel."""
+
+    if x.dtype != x_q_dtype:
+        # If the input is quantized, then it should be the same subdtype as w_q
+        if jnp.issubdtype(x_q_dtype, jnp.integer) != jnp.issubdtype(
+                w_q.dtype, jnp.integer):
+            raise ValueError(
+                f'{x_q_dtype=} and {w_q.dtype=} must be the same int or float type.'
+            )
+
+    # Verify input shapes.
+    if x.shape[1] != w_q.shape[1]:
+        raise ValueError(f'{x.shape[1]=} must be equal to {w_q.shape[1]=}')
+    if w_q.shape[0] != w_scale.shape[1]:
+        raise ValueError(
+            f'{w_q.shape[0]=} must be equal to {w_scale.shape[1]=}')
+    if x_abs_max.shape != (1, x.shape[0]):
+        raise ValueError(
+            f'{x_abs_max.shape=} must be equal to (1, {x.shape[0]=})')
+    if x.shape[0] % batch_block_size != 0:
+        raise ValueError(
+            f'{x.shape[0]=} must be a multiple of {batch_block_size=}')
+    if w_q.shape[0] % out_block_size != 0:
+        raise ValueError(
+            f'{w_q.shape[0]=} must be a multiple of {out_block_size=}')
+    if x.shape[1] % in_block_size != 0:
+        raise ValueError(
+            f'{x.shape[1]=} must be a multiple of {in_block_size=}')
+
+
+def matmul_kernel(
+    x_ref: jax.Array,  # (batch_block_size, in_block_size)
+    w_q_ref: jax.Array,  # (out_block_size, in_block_size)
+    w_scale_ref: jax.Array,  # (1, out_block_size)
+    x_abs_max_ref: jax.Array,  # (1, batch_block_size)
+    out_ref: jax.Array,  # (batch_block_size, out_block_size)
+    acc_scratch: jax.Array,  # (batch_block_size, out_block_size)
+    x_q_scratch: jax.Array,  # (batch_block_size, in_block_size)
+    x_scale_scratch: jax.Array,  # (batch_block_size, 1)
+    *,
+    x_q_dtype: jnp.dtype,
+    save_acc: bool,
+    save_x_q: bool,
+):
+    out_idx, in_idx = pl.program_id(1), pl.program_id(2)
+    n_in = pl.num_programs(2)
+    x_ref_dtype = x_ref.dtype
+
+    quantize_activation = x_q_dtype != x_ref_dtype
+
+    # Initialize conditional logic.
+    if save_x_q:
+        assert quantize_activation
+        assert x_q_scratch is not None
+        assert x_scale_scratch is not None
+        quant = out_idx == 0
+    else:
+        assert x_q_scratch is None
+        assert x_scale_scratch is None
+        quant = quantize_activation
+
+    if save_acc:
+        assert acc_scratch is not None
+        is_first_step = in_idx == 0
+        is_last_step = in_idx == (n_in - 1)
+    else:
+        assert acc_scratch is None
+        is_first_step = True
+        is_last_step = True
+
+    acc_dtype = jnp.float32
+    if quantize_activation and jnp.issubdtype(w_q_ref.dtype, jnp.integer):
+        acc_dtype = jnp.int32
+
+    # Start of actual computation logic.
+    def matmul_body(quant: bool, is_first_step: bool, is_last_step: bool):
+        if quantize_activation:
+            if quant:
+                x_q_tmp, x_scale_tmp = quantize_array(
+                    x_ref[...],
+                    x_abs_max_ref[...],
+                    x_q_dtype,
+                )
+
+                if save_x_q:
+                    x_q_scratch[...] = x_q_tmp
+                    x_scale_scratch[...] = x_scale_tmp
+
+            else:
+                assert save_x_q
+                x_q_tmp = x_q_scratch[...]
+                if is_last_step:
+                    x_scale_tmp = x_scale_scratch[...]
+
+            acc = jax.lax.dot_general(
+                x_q_tmp,
+                w_q_ref[...],
+                (((1, ), (1, )), ((), ())),
+                preferred_element_type=acc_dtype,
+            )
+        else:
+            acc = jax.lax.dot_general(
+                x_ref[...],
+                w_q_ref[...],
+                (((1, ), (1, )), ((), ())),
+                preferred_element_type=acc_dtype,
+            )
+
+        if not is_first_step:
+            acc += acc_scratch[...]
+
+        if is_last_step:
+            acc *= w_scale_ref[...]
+            if quantize_activation:
+                # TODO(kyuyeunk): Investigate caching broadcast.
+                acc *= x_scale_tmp
+            out_ref[...] = acc.astype(x_ref_dtype)
+        else:
+            assert save_acc
+            acc_scratch[...] = acc
+
+    unfold_args((quant, is_first_step, is_last_step), (), matmul_body)
+
+
+@functools.partial(
+    jax.jit,
+    static_argnames=[
+        'x_q_dtype',
+        'tuned_value',
+    ],
+)
+def quantized_matmul_kernel(
+    x: jax.Array,  # [bs, n_in]
+    w_q: jax.Array,  # [n_out, n_in]
+    w_scale: jax.Array,  # [n_out]
+    w_zp: jax.Array | None = None,  # [n_out]
+    block_size: int | None = None,
+    x_q_dtype: jnp.dtype | None = None,
+    *,
+    tuned_value: TunedValue | None = None,
+) -> jax.Array:
+    """Quantized matmul kernel.
+
+    Args:
+      x: Input unquantized array.
+      w_q: Weight quantized array. [n_output_features, n_input_features]
+      w_scale: Weight quantization scale. [n_output_features]
+      w_zp: Weight zero point for asymmetric quantization.
+      block_size: Block size for subchannel quantization.
+      x_q_dtype: Quantization type of the input. If None or if the value is the
+        same as x.dtype, then no quantization is applied.
+      tuned_value: Kernel tuned values for optimal performance.
+
+    Returns:
+      Quantized matmul result.
+    """
+
+    if w_zp is not None:
+        raise NotImplementedError('zero_point is not supported.')
+    if block_size is not None:
+        raise NotImplementedError('block_size is not supported.')
+
+    if x_q_dtype is None:
+        x_q_dtype = x.dtype
+    quantize_activation = x_q_dtype != x.dtype
+
+    # The Pallas kernel only has access to a single block of the input.
+    # Therefore, for per-token quantization, the abs max has to be computed
+    # outside of the kernel.
+    x_abs_max = jnp.max(jnp.abs(x), axis=-1, keepdims=False)  # [bs]
+    # Pallas requires minormost dim to be a multiple of sublane size 128.
+    # Therefore, instead of using [bs, 1], we reshape this into [1, bs].
+    x_abs_max = jnp.expand_dims(x_abs_max, axis=0)  # [1, bs]
+    assert x_abs_max.shape == (1, x.shape[0])
+
+    orig_n_batch, orig_n_in = x.shape
+    orig_n_out, _ = w_q.shape
+
+    if tuned_value is None:
+        tuned_value = get_tuned_block_sizes(
+            n_batch=orig_n_batch,
+            n_out=orig_n_out,
+            n_in=orig_n_in,
+            x_q_dtype=jnp.dtype(x_q_dtype).name,
+            w_q_dtype=jnp.dtype(w_q.dtype).name,
+        )
+    batch_block_size = tuned_value.batch_block_size
+    out_block_size = tuned_value.out_block_size
+    in_block_size = tuned_value.in_block_size
+
+    # Pad the inputs to be a multiple of the block size.
+    padded_n_batch = next_multiple(orig_n_batch, batch_block_size)
+    if orig_n_batch < padded_n_batch:
+        x = jnp.pad(x, ((0, padded_n_batch - orig_n_batch), (0, 0)))
+        x_abs_max = jnp.pad(x_abs_max,
+                            ((0, 0), (0, padded_n_batch - orig_n_batch)))
+    padded_n_out = next_multiple(orig_n_out, out_block_size)
+    if orig_n_out < padded_n_out:
+        w_q = jnp.pad(w_q, ((0, padded_n_out - orig_n_out), (0, 0)))
+        w_scale = jnp.pad(w_scale, (0, padded_n_out - orig_n_out))
+    padded_n_in = next_multiple(orig_n_in, in_block_size)
+    if orig_n_in < padded_n_in:
+        x = jnp.pad(x, ((0, 0), (0, padded_n_in - orig_n_in)))
+        w_q = jnp.pad(w_q, ((0, 0), (0, padded_n_in - orig_n_in)))
+
+    if w_scale.dtype != jnp.float32:
+        w_scale = w_scale.astype(jnp.float32)
+    w_scale = jnp.expand_dims(w_scale, axis=0)  # [1, n_output_features]
+
+    n_batch = padded_n_batch // batch_block_size
+    n_out = padded_n_out // out_block_size
+    n_in = padded_n_in // in_block_size
+
+    save_acc = n_in > 1
+    # Remove redundant input quantization logic by caching the quantized input.
+    # For best performance, only enable this behavior when a single input block
+    # is used per batch.
+    save_x_q = quantize_activation and n_in == 1 and n_out > 1
+
+    acc_dtype = jnp.float32
+    if quantize_activation and jnp.issubdtype(w_q.dtype, jnp.integer):
+        acc_dtype = jnp.int32
+
+    vmem_limit_bytes = get_vmem_limit(
+        n_batch=n_batch,
+        n_out=n_out,
+        n_in=n_in,
+        batch_block_size=batch_block_size,
+        out_block_size=out_block_size,
+        in_block_size=in_block_size,
+        x_dtype=x.dtype,
+        x_q_dtype=x_q_dtype,
+        w_q_dtype=w_q.dtype,
+        scale_dtype=jnp.float32,
+        out_dtype=x.dtype,
+        acc_dtype=acc_dtype,
+        save_acc=save_acc,
+        save_x_q=save_x_q,
+        upper_limit_bytes=get_device_vmem_limit(),
+    )
+
+    kernel = pl.pallas_call(
+        functools.partial(
+            matmul_kernel,
+            x_q_dtype=x_q_dtype,
+            save_acc=save_acc,
+            save_x_q=save_x_q,
+        ),
+        grid_spec=pltpu.PrefetchScalarGridSpec(
+            num_scalar_prefetch=0,
+            in_specs=[
+                pl.BlockSpec((batch_block_size, in_block_size), lambda b, o, i:
+                             (b, i)),  # x
+                pl.BlockSpec((out_block_size, in_block_size), lambda b, o, i:
+                             (o, i)),  # w_q
+                pl.BlockSpec((1, out_block_size), lambda b, o, i:
+                             (0, o)),  # w_scale
+                pl.BlockSpec((1, batch_block_size), lambda b, o, i:
+                             (0, b)),  # x_abs_max
+            ],
+            out_specs=pl.BlockSpec((batch_block_size, out_block_size),
+                                   lambda b, o, i: (b, o)),
+            scratch_shapes=[
+                pltpu.VMEM((batch_block_size, out_block_size), acc_dtype)
+                if save_acc else None,  # acc_scratch
+                pltpu.VMEM((batch_block_size, in_block_size), x_q_dtype)
+                if save_x_q else None,  # x_q_scratch
+                pltpu.VMEM(
+                    (batch_block_size,
+                     1), jnp.float32) if save_x_q else None,  # x_scale_scratch
+            ],
+            grid=(n_batch, n_out, n_in),
+        ),
+        out_shape=jax.ShapeDtypeStruct((padded_n_batch, padded_n_out),
+                                       x.dtype),
+        compiler_params=pltpu.CompilerParams(
+            dimension_semantics=('parallel', 'arbitrary', 'arbitrary'),
+            vmem_limit_bytes=vmem_limit_bytes,
+        ),
+    )
+
+    validate_inputs(
+        x=x,
+        w_q=w_q,
+        w_scale=w_scale,
+        x_abs_max=x_abs_max,
+        x_q_dtype=x_q_dtype,
+        batch_block_size=batch_block_size,
+        out_block_size=out_block_size,
+        in_block_size=in_block_size,
+    )
+
+    # The named_scope is used for autotune.
+    kernel_name = get_kernel_name(tuned_value)
+    with jax.named_scope(kernel_name):
+        out = kernel(x, w_q, w_scale, x_abs_max)
+
+    return out[:orig_n_batch, :orig_n_out]
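The public entry point in the hunk above is quantized_matmul_kernel: it computes the per-token abs max outside the kernel, pads every operand to the tuned block sizes, and launches the Pallas kernel over a (n_batch, n_out, n_in) grid. The sketch below is a hypothetical usage example, not code from the package: the shapes, the hand-rolled per-output-channel int8 weight quantization, and the final comparison against the pure-JAX reference are illustrative assumptions, and the pallas_call itself requires a TPU backend.

import jax
import jax.numpy as jnp

from tpu_inference.kernels.quantized_matmul.kernel import (
    quantized_matmul_kernel, xla_quantized_matmul)

key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (128, 512), dtype=jnp.bfloat16)   # [bs, n_in]
w = jax.random.normal(key_w, (1024, 512), dtype=jnp.bfloat16)  # [n_out, n_in]

# Symmetric per-output-channel int8 quantization of the weight, done by hand
# here for illustration (the package's util.quantize_tensor serves this role).
w_abs_max = jnp.max(jnp.abs(w), axis=-1, keepdims=True)  # [n_out, 1]
w_scale = (w_abs_max / 127.0).astype(jnp.float32)
w_q = jnp.round(w / w_scale).astype(jnp.int8)             # [n_out, n_in]

# x_q_dtype=jnp.int8 enables per-token activation quantization inside the
# kernel; block sizes come from get_tuned_block_sizes unless tuned_value is
# passed explicitly.
out = quantized_matmul_kernel(x, w_q, w_scale.squeeze(-1), x_q_dtype=jnp.int8)

# Pure-JAX reference implementation of the same computation, for comparison.
ref = xla_quantized_matmul(x, w_q, w_scale.squeeze(-1), quantize_activation=True)
print(jnp.max(jnp.abs(out.astype(jnp.float32) - ref.astype(jnp.float32))))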