tpu-inference 0.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_adapters.py +83 -0
- tests/core/test_core_tpu.py +523 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/test_lora.py +123 -0
- tests/test_base.py +201 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +218 -0
- tests/tpu_backend_test.py +59 -0
- tpu_inference/__init__.py +30 -0
- tpu_inference/adapters/__init__.py +0 -0
- tpu_inference/adapters/vllm_adapters.py +42 -0
- tpu_inference/adapters/vllm_config_adapters.py +134 -0
- tpu_inference/backend.py +69 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/adapters.py +153 -0
- tpu_inference/core/core_tpu.py +776 -0
- tpu_inference/core/disagg_executor.py +117 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/di/__init__.py +0 -0
- tpu_inference/di/abstracts.py +28 -0
- tpu_inference/di/host.py +76 -0
- tpu_inference/di/interfaces.py +51 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/tpu_connector.py +699 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +346 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/interfaces/__init__.py +0 -0
- tpu_inference/interfaces/cache.py +31 -0
- tpu_inference/interfaces/config.py +47 -0
- tpu_inference/interfaces/config_parts.py +117 -0
- tpu_inference/interfaces/engine.py +51 -0
- tpu_inference/interfaces/outputs.py +22 -0
- tpu_inference/interfaces/params.py +21 -0
- tpu_inference/interfaces/platform.py +74 -0
- tpu_inference/interfaces/request.py +39 -0
- tpu_inference/interfaces/scheduler.py +31 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +254 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/attention_interface.py +356 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/binary_search.py +295 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +172 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +95 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
- tpu_inference/layers/jax/sharding.py +406 -0
- tpu_inference/layers/jax/transformer_block.py +76 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +184 -0
- tpu_inference/layers/vllm/fused_moe.py +399 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +34 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
- tpu_inference/layers/vllm/sharding.py +151 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +308 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1233 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +433 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/llama3.py +366 -0
- tpu_inference/models/jax/llama4.py +473 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +976 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
- tpu_inference/models/jax/utils/weight_utils.py +510 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_jax.py +257 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table_jax.py +122 -0
- tpu_inference/runner/compilation_manager.py +672 -0
- tpu_inference/runner/input_batch_jax.py +435 -0
- tpu_inference/runner/kv_cache.py +119 -0
- tpu_inference/runner/kv_cache_manager.py +460 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +208 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +250 -0
- tpu_inference/runner/structured_decoding_manager.py +89 -0
- tpu_inference/runner/tpu_jax_runner.py +771 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +334 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +294 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/_temporary_vllm_compat.py +129 -0
- tpu_inference/worker/base.py +100 -0
- tpu_inference/worker/tpu_worker_jax.py +321 -0
- tpu_inference-0.11.1.dist-info/METADATA +101 -0
- tpu_inference-0.11.1.dist-info/RECORD +168 -0
- tpu_inference-0.11.1.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dist-info/top_level.txt +2 -0

tpu_inference/kernels/quantized_matmul/kernel.py
@@ -0,0 +1,395 @@
# SPDX-License-Identifier: Apache-2.0
"""Quantized matmul kernel."""

import functools

import jax
import jax.numpy as jnp
from jax._src import dtypes
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

from tpu_inference.kernels.quantized_matmul.tuned_block_sizes import (
    TunedValue, get_device_vmem_limit, get_tuned_block_sizes)
from tpu_inference.kernels.quantized_matmul.util import (get_kernel_name,
                                                         next_multiple,
                                                         unfold_args)


def quantize_array(
    x: jax.Array,  # [bs_block_size, in_block_size]
    x_abs_max: jax.Array,  # [1, bs_block_size]
    quant_dtype: jnp.dtype,
):
    is_float = jnp.issubdtype(quant_dtype, jnp.floating)
    dtype_info = jnp.finfo(quant_dtype) if is_float else jnp.iinfo(quant_dtype)
    dtype_max = float(dtype_info.max)

    # TODO(kyuyeunk): Investigate performance gain from non xlu transpose.
    scale = jnp.transpose(x_abs_max / dtype_max)
    return (x / scale).astype(quant_dtype), scale.astype(jnp.float32)


def get_vmem_limit(
    n_batch: int,
    n_out: int,
    n_in: int,
    batch_block_size: int,
    out_block_size: int,
    in_block_size: int,
    x_dtype: jnp.dtype,
    x_q_dtype: jnp.dtype,
    w_q_dtype: jnp.dtype,
    scale_dtype: jnp.dtype,
    out_dtype: jnp.dtype,
    acc_dtype: jnp.dtype,
    save_acc: bool,
    save_x_q: bool,
    upper_limit_bytes: int,
):
    """Calculate VMEM limit for the kernel."""

    # Calculate in/out VMEM size.
    x_size = batch_block_size * in_block_size * dtypes.bit_width(x_dtype)
    x_abs_max_size = batch_block_size * dtypes.bit_width(scale_dtype)
    w_q_size = out_block_size * in_block_size * dtypes.bit_width(w_q_dtype)
    w_scale_size = out_block_size * dtypes.bit_width(scale_dtype)
    out_size = batch_block_size * out_block_size * dtypes.bit_width(out_dtype)

    vmem_in_out = x_size + x_abs_max_size + w_q_size + w_scale_size + out_size
    vmem_in_out *= 2  # Account for compute and vreg spills.

    # Account for double buffering.
    # Double buffering is used only if there are multiple blocks per in/out.
    vmem_in_out += x_size if (n_batch > 1 or n_in > 1) else 0
    vmem_in_out += x_abs_max_size if (n_batch > 1) else 0
    vmem_in_out += w_q_size if (n_out > 1 or n_in > 1) else 0
    vmem_in_out += w_scale_size if (n_out > 1) else 0
    vmem_in_out += out_size if (n_batch > 1 or n_out > 1) else 0

    # Calculate scratch VMEM size.
    acc_size = batch_block_size * out_block_size * dtypes.bit_width(acc_dtype)
    x_q_size = batch_block_size * in_block_size * dtypes.bit_width(x_q_dtype)
    x_scale_size = batch_block_size * dtypes.bit_width(scale_dtype)

    vmem_scratch = acc_size if save_acc else 0
    vmem_scratch += x_q_size + x_scale_size if save_x_q else 0
    vmem_scratch *= 2  # Account for compute and vreg spills.

    # Add in/out and scratch VMEM size.
    vmem_used = vmem_in_out + vmem_scratch
    vmem_used_bytes = vmem_used // 8  # Convert bits to bytes.
    # Specify upper limit. Defaults to 96MB.
    vmem_limit_bytes = min(vmem_used_bytes, upper_limit_bytes)

    return vmem_limit_bytes


def validate_inputs(
    x: jax.Array,
    w_q: jax.Array,
    w_scale: jax.Array,
    x_abs_max: jax.Array,
    x_q_dtype: jnp.dtype,
    batch_block_size: int,
    out_block_size: int,
    in_block_size: int,
):
    """Verify inputs invoking the kernel."""

    if x.dtype != x_q_dtype:
        # If the input is quantized, then it should be the same subdtype as w_q
        if jnp.issubdtype(x_q_dtype, jnp.integer) != jnp.issubdtype(
                w_q.dtype, jnp.integer):
            raise ValueError(
                f'{x_q_dtype=} and {w_q.dtype=} must be the same int or float type.'
            )

    # Verify input shapes.
    if x.shape[1] != w_q.shape[1]:
        raise ValueError(f'{x.shape[1]=} must be equal to {w_q.shape[1]=}')
    if w_q.shape[0] != w_scale.shape[1]:
        raise ValueError(
            f'{w_q.shape[0]=} must be equal to {w_scale.shape[1]=}')
    if x_abs_max.shape != (1, x.shape[0]):
        raise ValueError(
            f'{x_abs_max.shape=} must be equal to (1, {x.shape[0]=})')
    if x.shape[0] % batch_block_size != 0:
        raise ValueError(
            f'{x.shape[0]=} must be a multiple of {batch_block_size=}')
    if w_q.shape[0] % out_block_size != 0:
        raise ValueError(
            f'{w_q.shape[0]=} must be a multiple of {out_block_size=}')
    if x.shape[1] % in_block_size != 0:
        raise ValueError(
            f'{x.shape[1]=} must be a multiple of {in_block_size=}')


def matmul_kernel(
    x_ref: jax.Array,  # (batch_block_size, in_block_size)
    w_q_ref: jax.Array,  # (out_block_size, in_block_size)
    w_scale_ref: jax.Array,  # (1, out_block_size)
    x_abs_max_ref: jax.Array,  # (1, batch_block_size)
    out_ref: jax.Array,  # (batch_block_size, out_block_size)
    acc_scratch: jax.Array,  # (batch_block_size, out_block_size)
    x_q_scratch: jax.Array,  # (batch_block_size, in_block_size)
    x_scale_scratch: jax.Array,  # (batch_block_size, 1)
    *,
    x_q_dtype: jnp.dtype,
    save_acc: bool,
    save_x_q: bool,
):
    out_idx, in_idx = pl.program_id(1), pl.program_id(2)
    n_in = pl.num_programs(2)
    x_ref_dtype = x_ref.dtype

    quantize_activation = x_q_dtype != x_ref_dtype

    # Initialize conditional logic.
    if save_x_q:
        assert quantize_activation
        assert x_q_scratch is not None
        assert x_scale_scratch is not None
        quant = out_idx == 0
    else:
        assert x_q_scratch is None
        assert x_scale_scratch is None
        quant = quantize_activation

    if save_acc:
        assert acc_scratch is not None
        is_first_step = in_idx == 0
        is_last_step = in_idx == (n_in - 1)
    else:
        assert acc_scratch is None
        is_first_step = True
        is_last_step = True

    acc_dtype = jnp.float32
    if quantize_activation and jnp.issubdtype(w_q_ref.dtype, jnp.integer):
        acc_dtype = jnp.int32

    # Start of actual computation logic.
    def matmul_body(quant: bool, is_first_step: bool, is_last_step: bool):
        if quantize_activation:
            if quant:
                x_q_tmp, x_scale_tmp = quantize_array(
                    x_ref[...],
                    x_abs_max_ref[...],
                    x_q_dtype,
                )

                if save_x_q:
                    x_q_scratch[...] = x_q_tmp
                    x_scale_scratch[...] = x_scale_tmp

            else:
                assert save_x_q
                x_q_tmp = x_q_scratch[...]
                if is_last_step:
                    x_scale_tmp = x_scale_scratch[...]

            acc = jax.lax.dot_general(
                x_q_tmp,
                w_q_ref[...],
                (((1, ), (1, )), ((), ())),
                preferred_element_type=acc_dtype,
            )
        else:
            acc = jax.lax.dot_general(
                x_ref[...],
                w_q_ref[...],
                (((1, ), (1, )), ((), ())),
                preferred_element_type=acc_dtype,
            )

        if not is_first_step:
            acc += acc_scratch[...]

        if is_last_step:
            acc *= w_scale_ref[...]
            if quantize_activation:
                # TODO(kyuyeunk): Investigate caching broadcast.
                acc *= x_scale_tmp
            out_ref[...] = acc.astype(x_ref_dtype)
        else:
            assert save_acc
            acc_scratch[...] = acc

    unfold_args((quant, is_first_step, is_last_step), (), matmul_body)


@functools.partial(
    jax.jit,
    static_argnames=[
        'x_q_dtype',
        'tuned_value',
    ],
)
def quantized_matmul_kernel(
    x: jax.Array,  # [bs, n_in]
    w_q: jax.Array,  # [n_out, n_in]
    w_scale: jax.Array,  # [n_out]
    w_zp: jax.Array | None = None,  # [n_out]
    block_size: int | None = None,
    x_q_dtype: jnp.dtype | None = None,
    *,
    tuned_value: TunedValue | None = None,
) -> jax.Array:
    """Quantized matmul kernel.

    Args:
        x: Input unquantized array.
        w_q: Weight quantized array. [n_output_features, n_input_features]
        w_scale: Weight quantization scale. [n_output_features]
        w_zp: Weight zero point for asymmetric quantization.
        block_size: Block size for subchannel quantization.
        x_q_dtype: Quantization type of the input. If None or if the value is
            the same as x.dtype, then no quantization is applied.
        tuned_value: Kernel tuned values for optimal performance.

    Returns:
        Quantized matmul result.
    """

    if w_zp is not None:
        raise NotImplementedError('zero_point is not supported.')
    if block_size is not None:
        raise NotImplementedError('block_size is not supported.')

    if x_q_dtype is None:
        x_q_dtype = x.dtype
    quantize_activation = x_q_dtype != x.dtype

    # Pallas kernel only has access to a single block of the input. Therefore,
    # for per-token quantization, abs max has to be computed outside of the
    # kernel.
    x_abs_max = jnp.max(jnp.abs(x), axis=-1, keepdims=False)  # [bs]
    # Pallas requires minormost dim to be a multiple of sublane size 128.
    # Therefore, instead of using [bs, 1], we reshape this into [1, bs]
    x_abs_max = jnp.expand_dims(x_abs_max, axis=0)  # [1, bs]
    assert x_abs_max.shape == (1, x.shape[0])

    orig_n_batch, orig_n_in = x.shape
    orig_n_out, _ = w_q.shape

    if tuned_value is None:
        tuned_value = get_tuned_block_sizes(
            n_batch=orig_n_batch,
            n_out=orig_n_out,
            n_in=orig_n_in,
            x_q_dtype=jnp.dtype(x_q_dtype).name,
            w_q_dtype=jnp.dtype(w_q.dtype).name,
        )
    batch_block_size = tuned_value.batch_block_size
    out_block_size = tuned_value.out_block_size
    in_block_size = tuned_value.in_block_size

    # Pad the inputs to be multiple of block size.
    padded_n_batch = next_multiple(orig_n_batch, batch_block_size)
    if orig_n_batch < padded_n_batch:
        x = jnp.pad(x, ((0, padded_n_batch - orig_n_batch), (0, 0)))
        x_abs_max = jnp.pad(x_abs_max,
                            ((0, 0), (0, padded_n_batch - orig_n_batch)))
    padded_n_out = next_multiple(orig_n_out, out_block_size)
    if orig_n_out < padded_n_out:
        w_q = jnp.pad(w_q, ((0, padded_n_out - orig_n_out), (0, 0)))
        w_scale = jnp.pad(w_scale, (0, padded_n_out - orig_n_out))
    padded_n_in = next_multiple(orig_n_in, in_block_size)
    if orig_n_in < padded_n_in:
        x = jnp.pad(x, ((0, 0), (0, padded_n_in - orig_n_in)))
        w_q = jnp.pad(w_q, ((0, 0), (0, padded_n_in - orig_n_in)))

    if w_scale.dtype != jnp.float32:
        w_scale = w_scale.astype(jnp.float32)
    w_scale = jnp.expand_dims(w_scale, axis=0)  # [1, n_output_features]

    n_batch = padded_n_batch // batch_block_size
    n_out = padded_n_out // out_block_size
    n_in = padded_n_in // in_block_size

    save_acc = n_in > 1
    # Remove redundant input quantization logic by caching quantized input. For
    # best performance, only enable this behavior when single input block is
    # used per batch.
    save_x_q = quantize_activation and n_in == 1 and n_out > 1

    acc_dtype = jnp.float32
    if quantize_activation and jnp.issubdtype(w_q.dtype, jnp.integer):
        acc_dtype = jnp.int32

    vmem_limit_bytes = get_vmem_limit(
        n_batch=n_batch,
        n_out=n_out,
        n_in=n_in,
        batch_block_size=batch_block_size,
        out_block_size=out_block_size,
        in_block_size=in_block_size,
        x_dtype=x.dtype,
        x_q_dtype=x_q_dtype,
        w_q_dtype=w_q.dtype,
        scale_dtype=jnp.float32,
        out_dtype=x.dtype,
        acc_dtype=acc_dtype,
        save_acc=save_acc,
        save_x_q=save_x_q,
        upper_limit_bytes=get_device_vmem_limit(),
    )

    kernel = pl.pallas_call(
        functools.partial(
            matmul_kernel,
            x_q_dtype=x_q_dtype,
            save_acc=save_acc,
            save_x_q=save_x_q,
        ),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            in_specs=[
                pl.BlockSpec((batch_block_size, in_block_size), lambda b, o, i:
                             (b, i)),  # x
                pl.BlockSpec((out_block_size, in_block_size), lambda b, o, i:
                             (o, i)),  # w_q
                pl.BlockSpec((1, out_block_size), lambda b, o, i:
                             (0, o)),  # w_scale
                pl.BlockSpec((1, batch_block_size), lambda b, o, i:
                             (0, b)),  # x_abs_max
            ],
            out_specs=pl.BlockSpec((batch_block_size, out_block_size),
                                   lambda b, o, i: (b, o)),
            scratch_shapes=[
                pltpu.VMEM((batch_block_size, out_block_size), acc_dtype)
                if save_acc else None,  # acc_scratch
                pltpu.VMEM((batch_block_size, in_block_size), x_q_dtype)
                if save_x_q else None,  # x_q_scratch
                pltpu.VMEM(
                    (batch_block_size,
                     1), jnp.float32) if save_x_q else None,  # x_scale_scratch
            ],
            grid=(n_batch, n_out, n_in),
        ),
        out_shape=jax.ShapeDtypeStruct((padded_n_batch, padded_n_out),
                                       x.dtype),
        compiler_params=pltpu.CompilerParams(
            dimension_semantics=('parallel', 'arbitrary', 'arbitrary'),
            vmem_limit_bytes=vmem_limit_bytes,
        ),
    )

    validate_inputs(
        x=x,
        w_q=w_q,
        w_scale=w_scale,
        x_abs_max=x_abs_max,
        x_q_dtype=x_q_dtype,
        batch_block_size=batch_block_size,
        out_block_size=out_block_size,
        in_block_size=in_block_size,
    )

    # The named_scope is used for autotune.
    kernel_name = get_kernel_name(tuned_value)
    with jax.named_scope(kernel_name):
        out = kernel(x, w_q, w_scale, x_abs_max)

    return out[:orig_n_batch, :orig_n_out]
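
For orientation, the sketch below shows roughly how the quantized_matmul_kernel entry point above might be called. It is a minimal illustration, not part of the package: the problem sizes, the bfloat16 activations, and the abs-max int8 weight quantization step are assumptions made for the example, and actually executing it requires a TPU runtime with Pallas support.

import jax
import jax.numpy as jnp

from tpu_inference.kernels.quantized_matmul.kernel import quantized_matmul_kernel

# Hypothetical problem sizes, chosen only for illustration.
bs, n_in, n_out = 128, 1024, 4096
key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (bs, n_in), dtype=jnp.bfloat16)
w = jax.random.normal(key_w, (n_out, n_in), dtype=jnp.float32)

# Symmetric per-output-channel int8 weight quantization (a plain abs-max
# scheme used here for the example, not the package's checkpoint loader).
w_scale = jnp.max(jnp.abs(w), axis=-1) / 127.0          # [n_out]
w_q = jnp.round(w / w_scale[:, None]).astype(jnp.int8)  # [n_out, n_in]

# x_q_dtype=jnp.int8 additionally quantizes activations per token inside the
# kernel; leaving it as None keeps activations in bfloat16.
out = quantized_matmul_kernel(x, w_q, w_scale, x_q_dtype=jnp.int8)
assert out.shape == (bs, n_out)  # Output dtype matches x.dtype.

Passing tuned_value explicitly pins the block sizes instead of looking them up via get_tuned_block_sizes from tuned_block_sizes.py.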