tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,456 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Quantized matmul kernel."""
3
+
4
+ import functools
5
+
6
+ import jax
7
+ import jax.numpy as jnp
8
+ from jax._src import dtypes
9
+ from jax.experimental import pallas as pl
10
+ from jax.experimental.pallas import tpu as pltpu
11
+
12
+ from tpu_inference.kernels.quantized_matmul import util
13
+ from tpu_inference.kernels.quantized_matmul.tuned_block_sizes import (
14
+ TunedValue, get_device_vmem_limit, get_tuned_block_sizes)
15
+ from tpu_inference.kernels.quantized_matmul.util import (get_kernel_name,
16
+ next_multiple,
17
+ unfold_args)
18
+
19
# Module-level alias for util.quantize_tensor; xla_quantized_matmul uses it to
# quantize the activation in its pure-JAX reference path.
quantize_tensor = util.quantize_tensor
20
+
21
+
22
def xla_quantized_matmul(
    x: jax.Array,
    w_q: jax.Array,
    w_scale: jax.Array,
    quantize_activation=True,
) -> jax.Array:
    """
    Reference (pure JAX) implementation of the quantized matmul kernel below.

    Computes ``x @ w_q.T`` (dim 1 of ``x`` is contracted with dim 1 of ``w_q``)
    and rescales the result by ``w_scale`` and, when the activation is
    quantized, by the activation scale returned by ``quantize_tensor``.

    Args:
      x: Activation. Assumed 2-D [n_batch, n_input_features] since dim 1 is
        the contracted dimension.
      w_q: Weight quantized array. [n_output_features, n_input_features]
      w_scale: Weight quantization scale. [n_output_features]
      quantize_activation: If True, quantize ``x`` to ``w_q.dtype`` with
        ``quantize_tensor`` before the matmul and multiply the result by the
        returned activation scale. If False, ``x`` is multiplied with ``w_q``
        directly.

    Returns:
      Output of the quantized matmul, cast to ``x.dtype``.
      [n_batch, n_output_features]
    """
    # Contract dim 1 of x with dim 1 of w_q, no batch dims (x @ w_q.T).
    contract_dims = (((1, ), (1, )), ((), ()))

    if quantize_activation:
        # Integer x integer products are accumulated in int32; float (e.g.
        # fp8) products are accumulated in float32.
        acc_dtype = (jnp.int32 if jnp.issubdtype(w_q.dtype, jnp.integer) else
                     jnp.float32)

        x_q, x_scale = quantize_tensor(x, w_q.dtype)
        out = jax.lax.dot_general(
            x_q,
            w_q,
            dimension_numbers=contract_dims,
            preferred_element_type=acc_dtype,
        ).astype(jnp.float32)
        out *= x_scale
    else:
        out = jax.lax.dot_general(
            x,
            w_q,
            dimension_numbers=contract_dims,
            preferred_element_type=jnp.float32,
        )
    # Per-output-feature weight scale, broadcast over the batch dimension.
    out *= jnp.expand_dims(w_scale, 0)
    return out.astype(x.dtype)
63
+
64
+
65
def quantize_array(
    x: jax.Array,  # [bs_block_size, in_block_size]
    x_abs_max: jax.Array,  # [1, bs_block_size]
    quant_dtype: jnp.dtype,
):
    """Quantizes ``x`` row by row to ``quant_dtype``.

    Each row is divided by ``abs_max / dtype_max`` so that its largest
    magnitude maps onto the largest finite value of ``quant_dtype``.

    Args:
      x: Block of activations. [bs_block_size, in_block_size]
      x_abs_max: Per-row absolute maximum of ``x`` laid out as a row vector.
        [1, bs_block_size]
      quant_dtype: Target integer or floating point dtype.

    Returns:
      ``(x_q, scale)``: ``x_q`` is ``x`` cast to ``quant_dtype`` after scaling,
      and ``scale`` is the float32 per-row scale [bs_block_size, 1]. Rows whose
      abs max is 0 get ``scale == 0`` and ``x_q == 0``.
    """
    is_float = jnp.issubdtype(quant_dtype, jnp.floating)
    dtype_info = jnp.finfo(quant_dtype) if is_float else jnp.iinfo(quant_dtype)
    dtype_max = float(dtype_info.max)

    # TODO(kyuyeunk): Investigate performance gain from non xlu transpose.
    scale = jnp.transpose(x_abs_max / dtype_max)  # [bs_block_size, 1]
    # An all-zero row (e.g. batch padding) has a zero scale, and 0 / 0 = NaN:
    # casting NaN to an integer dtype is undefined and, for float8, NaN would
    # survive the later `acc *= scale`. Such rows are exactly zero, so dividing
    # them by 1 yields the correct quantized value of 0.
    safe_scale = jnp.where(scale == 0, 1.0, scale)
    return (x / safe_scale).astype(quant_dtype), scale.astype(jnp.float32)
77
+
78
+
79
+ def get_vmem_limit(
80
+ n_batch: int,
81
+ n_out: int,
82
+ n_in: int,
83
+ batch_block_size: int,
84
+ out_block_size: int,
85
+ in_block_size: int,
86
+ x_dtype: jnp.dtype,
87
+ x_q_dtype: jnp.dtype,
88
+ w_q_dtype: jnp.dtype,
89
+ scale_dtype: jnp.dtype,
90
+ out_dtype: jnp.dtype,
91
+ acc_dtype: jnp.dtype,
92
+ save_acc: bool,
93
+ save_x_q: bool,
94
+ upper_limit_bytes: int,
95
+ ):
96
+ """Calculate VMEM limit for the kernel."""
97
+
98
+ # Calculate in/out VMEM size.
99
+ x_size = (batch_block_size *
100
+ in_block_size * (dtypes.bit_width(x_dtype) if hasattr(
101
+ dtypes, "bit_width") else dtypes.itemsize_bits(x_dtype)))
102
+ x_abs_max_size = (
103
+ batch_block_size * (dtypes.bit_width(scale_dtype) if hasattr(
104
+ dtypes, "bit_width") else dtypes.itemsize_bits(scale_dtype)))
105
+ w_q_size = (out_block_size *
106
+ in_block_size * (dtypes.bit_width(w_q_dtype) if hasattr(
107
+ dtypes, "bit_width") else dtypes.itemsize_bits(w_q_dtype)))
108
+ w_scale_size = (out_block_size * (dtypes.bit_width(scale_dtype) if hasattr(
109
+ dtypes, "bit_width") else dtypes.itemsize_bits(scale_dtype)))
110
+ out_size = (batch_block_size *
111
+ out_block_size * (dtypes.bit_width(out_dtype) if hasattr(
112
+ dtypes, "bit_width") else dtypes.itemsize_bits(out_dtype)))
113
+
114
+ vmem_in_out = x_size + x_abs_max_size + w_q_size + w_scale_size + out_size
115
+ vmem_in_out *= 2 # Account for compute and vreg spills.
116
+
117
+ # Account for double buffering.
118
+ # Double buffering is used only if there are multiple blocks per in/out.
119
+ vmem_in_out += x_size if (n_batch > 1 or n_in > 1) else 0
120
+ vmem_in_out += x_abs_max_size if (n_batch > 1) else 0
121
+ vmem_in_out += w_q_size if (n_out > 1 or n_in > 1) else 0
122
+ vmem_in_out += w_scale_size if (n_out > 1) else 0
123
+ vmem_in_out += out_size if (n_batch > 1 or n_out > 1) else 0
124
+
125
+ # Calculate scratch VMEM size.
126
+ acc_size = (batch_block_size *
127
+ out_block_size * (dtypes.bit_width(acc_dtype) if hasattr(
128
+ dtypes, "bit_width") else dtypes.itemsize_bits(acc_dtype)))
129
+ x_q_size = (batch_block_size *
130
+ in_block_size * (dtypes.bit_width(x_q_dtype) if hasattr(
131
+ dtypes, "bit_width") else dtypes.itemsize_bits(x_q_dtype)))
132
+ x_scale_size = (
133
+ batch_block_size * (dtypes.bit_width(scale_dtype) if hasattr(
134
+ dtypes, "bit_width") else dtypes.itemsize_bits(scale_dtype)))
135
+
136
+ vmem_scratch = acc_size if save_acc else 0
137
+ vmem_scratch += x_q_size + x_scale_size if save_x_q else 0
138
+ vmem_scratch *= 2 # Account for compute and vreg spills.
139
+
140
+ # Add in/out and scratch VMEM size.
141
+ vmem_used = vmem_in_out + vmem_scratch
142
+ vmem_used_bytes = vmem_used // 8 # Convert bits to bytes.
143
+ # Specify upper limit. Defaults to 96MB.
144
+ vmem_limit_bytes = min(vmem_used_bytes, upper_limit_bytes)
145
+
146
+ return vmem_limit_bytes
147
+
148
+
149
def validate_inputs(
    x: jax.Array,
    w_q: jax.Array,
    w_scale: jax.Array,
    x_abs_max: jax.Array,
    x_q_dtype: jnp.dtype,
    batch_block_size: int,
    out_block_size: int,
    in_block_size: int,
):
    """Check that the kernel operands are mutually consistent.

    Checks run in a fixed order and the first failing one raises.

    Raises:
      ValueError: If a requested activation quantization dtype is not the same
        int/float kind as ``w_q``, if operand shapes disagree, or if a
        dimension is not a multiple of its block size.
    """

    def _require(condition: bool, message: str) -> None:
        # Raise with `message` unless `condition` holds.
        if not condition:
            raise ValueError(message)

    if x.dtype != x_q_dtype:
        # A quantized activation must share int-ness with the weights.
        x_q_is_int = jnp.issubdtype(x_q_dtype, jnp.integer)
        w_q_is_int = jnp.issubdtype(w_q.dtype, jnp.integer)
        _require(
            x_q_is_int == w_q_is_int,
            f'{x_q_dtype=} and {w_q.dtype=} must be the same int or float type.',
        )

    # Shape agreement between operands.
    _require(x.shape[1] == w_q.shape[1],
             f'{x.shape[1]=} must be equal to {w_q.shape[1]=}')
    _require(w_q.shape[0] == w_scale.shape[1],
             f'{w_q.shape[0]=} must be equal to {w_scale.shape[1]=}')
    _require(x_abs_max.shape == (1, x.shape[0]),
             f'{x_abs_max.shape=} must be equal to (1, {x.shape[0]=})')

    # Every dimension must tile evenly into blocks.
    _require(x.shape[0] % batch_block_size == 0,
             f'{x.shape[0]=} must be a multiple of {batch_block_size=}')
    _require(w_q.shape[0] % out_block_size == 0,
             f'{w_q.shape[0]=} must be a multiple of {out_block_size=}')
    _require(x.shape[1] % in_block_size == 0,
             f'{x.shape[1]=} must be a multiple of {in_block_size=}')
187
+
188
+
189
def matmul_kernel(
    x_ref: jax.Array,  # (batch_block_size, in_block_size)
    w_q_ref: jax.Array,  # (out_block_size, in_block_size)
    w_scale_ref: jax.Array,  # (1, out_block_size)
    x_abs_max_ref: jax.Array,  # (1, batch_block_size)
    out_ref: jax.Array,  # (batch_block_size, out_block_size)
    acc_scratch: jax.Array,  # (batch_block_size, out_block_size)
    x_q_scratch: jax.Array,  # (batch_block_size, in_block_size)
    x_scale_scratch: jax.Array,  # (batch_block_size, 1)
    *,
    x_q_dtype: jnp.dtype,
    save_acc: bool,
    save_x_q: bool,
):
    """Pallas kernel body for one grid step of the quantized matmul.

    Multiplies the current activation block by the transposed weight block
    (dim 1 of each is contracted). On the last input step it writes
    ``acc * w_scale`` (times the activation scale when quantizing), cast to
    ``x_ref.dtype``, into ``out_ref``. On earlier steps the partial sum is
    stored in ``acc_scratch``.

    Args:
      x_ref: Activation block.
      w_q_ref: Quantized weight block.
      w_scale_ref: Weight scale block.
      x_abs_max_ref: Per-row abs max of the activation; read only when
        quantizing the activation.
      out_ref: Output block.
      acc_scratch: Partial-sum scratch carried across input steps; must be
        None unless ``save_acc``.
      x_q_scratch: Cache of the quantized activation block; must be None
        unless ``save_x_q``.
      x_scale_scratch: Cache of the activation scale; must be None unless
        ``save_x_q``.
      x_q_dtype: Activation quantization dtype. The activation is quantized
        iff it differs from ``x_ref.dtype``.
      save_acc: Whether partial sums are accumulated across input steps
        (grid axis 2).
      save_x_q: Whether the activation quantized at out step 0 is cached and
        reused for later out steps.
    """
    out_idx, in_idx = pl.program_id(1), pl.program_id(2)
    n_in = pl.num_programs(2)
    x_ref_dtype = x_ref.dtype

    quantize_activation = x_q_dtype != x_ref_dtype

    # Initialize conditional logic.
    if save_x_q:
        assert quantize_activation
        assert x_q_scratch is not None
        assert x_scale_scratch is not None
        # Quantize only on the first out step; later out steps reuse the
        # cached result. NOTE(review): this relies on the out axis being
        # iterated in order for each batch block (the caller marks it
        # 'arbitrary') — confirm if dimension_semantics change.
        quant = out_idx == 0
    else:
        assert x_q_scratch is None
        assert x_scale_scratch is None
        quant = quantize_activation

    if save_acc:
        assert acc_scratch is not None
        is_first_step = in_idx == 0
        is_last_step = in_idx == (n_in - 1)
    else:
        # A single input step: it is both the first and the last one.
        assert acc_scratch is None
        is_first_step = True
        is_last_step = True

    # int x int products accumulate in int32; everything else in float32.
    acc_dtype = jnp.float32
    if quantize_activation and jnp.issubdtype(w_q_ref.dtype, jnp.integer):
        acc_dtype = jnp.int32

    # Start of actual computation logic.
    def matmul_body(quant: bool, is_first_step: bool, is_last_step: bool):
        if quantize_activation:
            if quant:
                x_q_tmp, x_scale_tmp = quantize_array(
                    x_ref[...],
                    x_abs_max_ref[...],
                    x_q_dtype,
                )

                if save_x_q:
                    x_q_scratch[...] = x_q_tmp
                    x_scale_scratch[...] = x_scale_tmp

            else:
                # Reuse the activation quantized on an earlier out step.
                assert save_x_q
                x_q_tmp = x_q_scratch[...]
                if is_last_step:
                    x_scale_tmp = x_scale_scratch[...]

            acc = jax.lax.dot_general(
                x_q_tmp,
                w_q_ref[...],
                (((1, ), (1, )), ((), ())),
                preferred_element_type=acc_dtype,
            )
        else:
            acc = jax.lax.dot_general(
                x_ref[...],
                w_q_ref[...],
                (((1, ), (1, )), ((), ())),
                preferred_element_type=acc_dtype,
            )

        if not is_first_step:
            acc += acc_scratch[...]

        if is_last_step:
            # w_scale is float32 (the caller casts it), so an int32
            # accumulator is promoted to float32 here.
            acc *= w_scale_ref[...]
            if quantize_activation:
                # TODO(kyuyeunk): Investigate caching broadcast.
                acc *= x_scale_tmp
            out_ref[...] = acc.astype(x_ref_dtype)
        else:
            assert save_acc
            acc_scratch[...] = acc

    # NOTE(review): unfold_args (from util) is handed the three flags, some of
    # which may be traced (e.g. `out_idx == 0`); presumably it branches on
    # them so matmul_body sees Python bools — confirm against util.
    unfold_args((quant, is_first_step, is_last_step), (), matmul_body)
281
+
282
+
283
@functools.partial(
    jax.jit,
    static_argnames=[
        'x_q_dtype',
        'tuned_value',
    ],
)
def quantized_matmul_kernel(
    x: jax.Array,  # [bs, n_in]
    w_q: jax.Array,  # [n_out, n_in]
    w_scale: jax.Array,  # [n_out]
    w_zp: jax.Array | None = None,  # [n_out]
    block_size: int | None = None,
    x_q_dtype: jnp.dtype | None = None,
    *,
    tuned_value: TunedValue | None = None,
) -> jax.Array:
    """Quantized matmul kernel.

    Computes ``x @ w_q.T`` scaled per output feature by ``w_scale``. It
    optionally quantizes ``x`` per row to ``x_q_dtype`` inside the Pallas
    kernel. Inputs are zero-padded to multiples of the tuned block sizes, and
    the padding is sliced off the result.

    Args:
      x: Input unquantized array. [bs, n_input_features]
      w_q: Weight quantized array. [n_output_features, n_input_features]
      w_scale: Weight quantization scale. [n_output_features]
      w_zp: Weight zero point for asymmetric quantization. Not supported;
        must be None.
      block_size: Block size for subchannel quantization. Not supported;
        must be None.
      x_q_dtype: Quantization type of the input. If None or if the value is the
        same as x.dtype, then no quantization is applied.
      tuned_value: Kernel tuned values for optimal performance. Looked up via
        ``get_tuned_block_sizes`` when None.

    Returns:
      Quantized matmul result with dtype ``x.dtype``. [bs, n_output_features]

    Raises:
      NotImplementedError: If ``w_zp`` or ``block_size`` is given.
      ValueError: If the operands are inconsistent (see ``validate_inputs``).
    """

    if w_zp is not None:
        raise NotImplementedError('zero_point is not supported.')
    if block_size is not None:
        raise NotImplementedError('block_size is not supported.')

    if x_q_dtype is None:
        x_q_dtype = x.dtype
    quantize_activation = x_q_dtype != x.dtype

    # Pallas kernel only has access to a single block of the input. Therefore,
    # for per-token quantization, abs max has to be computed outside of the
    # kernel. Without activation quantization the kernel never reads it, so a
    # zero placeholder of the same shape avoids a full reduction over x.
    if quantize_activation:
        x_abs_max = jnp.max(jnp.abs(x), axis=-1, keepdims=False)  # [bs]
    else:
        x_abs_max = jnp.zeros(x.shape[:-1], dtype=x.dtype)  # [bs]
    # Pallas TPU blocks want the minormost dim to be a multiple of the
    # 128-wide lane dimension. Therefore, instead of using [bs, 1], we
    # reshape this into [1, bs].
    x_abs_max = jnp.expand_dims(x_abs_max, axis=0)  # [1, bs]
    assert x_abs_max.shape == (1, x.shape[0])

    orig_n_batch, orig_n_in = x.shape
    orig_n_out, _ = w_q.shape

    if tuned_value is None:
        tuned_value = get_tuned_block_sizes(
            n_batch=orig_n_batch,
            n_out=orig_n_out,
            n_in=orig_n_in,
            x_q_dtype=jnp.dtype(x_q_dtype).name,
            w_q_dtype=jnp.dtype(w_q.dtype).name,
        )
    batch_block_size = tuned_value.batch_block_size
    out_block_size = tuned_value.out_block_size
    in_block_size = tuned_value.in_block_size

    # Zero-pad every dimension up to a multiple of its block size, using a
    # single jnp.pad per array. Zero padding along n_in adds nothing to the
    # dot product; padded batch rows / out columns are sliced off at the end.
    padded_n_batch = next_multiple(orig_n_batch, batch_block_size)
    padded_n_out = next_multiple(orig_n_out, out_block_size)
    padded_n_in = next_multiple(orig_n_in, in_block_size)
    pad_batch = padded_n_batch - orig_n_batch
    pad_out = padded_n_out - orig_n_out
    pad_in = padded_n_in - orig_n_in

    if pad_batch > 0 or pad_in > 0:
        x = jnp.pad(x, ((0, pad_batch), (0, pad_in)))
    if pad_batch > 0:
        x_abs_max = jnp.pad(x_abs_max, ((0, 0), (0, pad_batch)))
    if pad_out > 0 or pad_in > 0:
        w_q = jnp.pad(w_q, ((0, pad_out), (0, pad_in)))
    if pad_out > 0:
        w_scale = jnp.pad(w_scale, (0, pad_out))

    if w_scale.dtype != jnp.float32:
        w_scale = w_scale.astype(jnp.float32)
    w_scale = jnp.expand_dims(w_scale, axis=0)  # [1, n_output_features]

    validate_inputs(
        x=x,
        w_q=w_q,
        w_scale=w_scale,
        x_abs_max=x_abs_max,
        x_q_dtype=x_q_dtype,
        batch_block_size=batch_block_size,
        out_block_size=out_block_size,
        in_block_size=in_block_size,
    )

    n_batch = padded_n_batch // batch_block_size
    n_out = padded_n_out // out_block_size
    n_in = padded_n_in // in_block_size

    save_acc = n_in > 1
    # Remove redundant input quantization logic by caching quantized input. For
    # best performance, only enable this behavior when single input block is
    # used per batch.
    save_x_q = quantize_activation and n_in == 1 and n_out > 1

    acc_dtype = jnp.float32
    if quantize_activation and jnp.issubdtype(w_q.dtype, jnp.integer):
        acc_dtype = jnp.int32

    vmem_limit_bytes = get_vmem_limit(
        n_batch=n_batch,
        n_out=n_out,
        n_in=n_in,
        batch_block_size=batch_block_size,
        out_block_size=out_block_size,
        in_block_size=in_block_size,
        x_dtype=x.dtype,
        x_q_dtype=x_q_dtype,
        w_q_dtype=w_q.dtype,
        scale_dtype=jnp.float32,
        out_dtype=x.dtype,
        acc_dtype=acc_dtype,
        save_acc=save_acc,
        save_x_q=save_x_q,
        upper_limit_bytes=get_device_vmem_limit(),
    )

    kernel = pl.pallas_call(
        functools.partial(
            matmul_kernel,
            x_q_dtype=x_q_dtype,
            save_acc=save_acc,
            save_x_q=save_x_q,
        ),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            in_specs=[
                pl.BlockSpec((batch_block_size, in_block_size), lambda b, o, i:
                             (b, i)),  # x
                pl.BlockSpec((out_block_size, in_block_size), lambda b, o, i:
                             (o, i)),  # w_q
                pl.BlockSpec((1, out_block_size), lambda b, o, i:
                             (0, o)),  # w_scale
                pl.BlockSpec((1, batch_block_size), lambda b, o, i:
                             (0, b)),  # x_abs_max
            ],
            out_specs=pl.BlockSpec((batch_block_size, out_block_size),
                                   lambda b, o, i: (b, o)),
            scratch_shapes=[
                pltpu.VMEM((batch_block_size, out_block_size), acc_dtype)
                if save_acc else None,  # acc_scratch
                pltpu.VMEM((batch_block_size, in_block_size), x_q_dtype)
                if save_x_q else None,  # x_q_scratch
                pltpu.VMEM(
                    (batch_block_size,
                     1), jnp.float32) if save_x_q else None,  # x_scale_scratch
            ],
            grid=(n_batch, n_out, n_in),
        ),
        out_shape=jax.ShapeDtypeStruct((padded_n_batch, padded_n_out),
                                       x.dtype),
        compiler_params=pltpu.CompilerParams(
            dimension_semantics=('parallel', 'arbitrary', 'arbitrary'),
            vmem_limit_bytes=vmem_limit_bytes,
        ),
    )

    # The named_scope is used for autotune.
    kernel_name = get_kernel_name(tuned_value)
    with jax.named_scope(kernel_name):
        out = kernel(x, w_q, w_scale, x_abs_max)

    # Drop the padded batch rows and output columns.
    return out[:orig_n_batch, :orig_n_out]