tpu-inference 0.11.1.dev202511150811__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of tpu-inference might be problematic.

Files changed (179)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +105 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +654 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +96 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +182 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +236 -0
  27. tpu_inference/__init__.py +34 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +51 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +728 -0
  37. tpu_inference/distributed/utils.py +59 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +107 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +362 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +390 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +507 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +105 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +311 -0
  120. tpu_inference/mock/__init__.py +0 -0
  121. tpu_inference/mock/vllm_config_utils.py +28 -0
  122. tpu_inference/mock/vllm_envs.py +1219 -0
  123. tpu_inference/mock/vllm_logger.py +212 -0
  124. tpu_inference/mock/vllm_logging_utils.py +15 -0
  125. tpu_inference/models/__init__.py +0 -0
  126. tpu_inference/models/common/__init__.py +0 -0
  127. tpu_inference/models/common/model_loader.py +444 -0
  128. tpu_inference/models/jax/__init__.py +0 -0
  129. tpu_inference/models/jax/deepseek_v3.py +868 -0
  130. tpu_inference/models/jax/gpt_oss.py +492 -0
  131. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  132. tpu_inference/models/jax/llama3.py +375 -0
  133. tpu_inference/models/jax/llama4.py +629 -0
  134. tpu_inference/models/jax/llama_eagle3.py +333 -0
  135. tpu_inference/models/jax/phi3.py +376 -0
  136. tpu_inference/models/jax/qwen2.py +375 -0
  137. tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
  138. tpu_inference/models/jax/qwen3.py +302 -0
  139. tpu_inference/models/jax/utils/__init__.py +0 -0
  140. tpu_inference/models/jax/utils/file_utils.py +96 -0
  141. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  142. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  143. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  144. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  145. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  146. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  147. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  148. tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
  149. tpu_inference/models/jax/utils/weight_utils.py +529 -0
  150. tpu_inference/models/vllm/__init__.py +0 -0
  151. tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
  152. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  153. tpu_inference/platforms/__init__.py +2 -0
  154. tpu_inference/platforms/tpu_platform.py +269 -0
  155. tpu_inference/runner/__init__.py +0 -0
  156. tpu_inference/runner/block_table.py +122 -0
  157. tpu_inference/runner/compilation_manager.py +780 -0
  158. tpu_inference/runner/input_batch.py +435 -0
  159. tpu_inference/runner/kv_cache.py +132 -0
  160. tpu_inference/runner/kv_cache_manager.py +479 -0
  161. tpu_inference/runner/lora_utils.py +92 -0
  162. tpu_inference/runner/multimodal_manager.py +217 -0
  163. tpu_inference/runner/persistent_batch_manager.py +244 -0
  164. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  165. tpu_inference/runner/structured_decoding_manager.py +88 -0
  166. tpu_inference/runner/tpu_runner.py +1620 -0
  167. tpu_inference/runner/utils.py +426 -0
  168. tpu_inference/spec_decode/__init__.py +0 -0
  169. tpu_inference/spec_decode/jax/__init__.py +0 -0
  170. tpu_inference/spec_decode/jax/eagle3.py +367 -0
  171. tpu_inference/tpu_info.py +77 -0
  172. tpu_inference/utils.py +317 -0
  173. tpu_inference/worker/__init__.py +0 -0
  174. tpu_inference/worker/tpu_worker.py +321 -0
  175. tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
  176. tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
  177. tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
  178. tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
  179. tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tpu_inference/kernels/quantized_matmul/kernel.py
@@ -0,0 +1,395 @@
+ # SPDX-License-Identifier: Apache-2.0
+ """Quantized matmul kernel."""
+
+ import functools
+
+ import jax
+ import jax.numpy as jnp
+ from jax._src import dtypes
+ from jax.experimental import pallas as pl
+ from jax.experimental.pallas import tpu as pltpu
+
+ from tpu_inference.kernels.quantized_matmul.tuned_block_sizes import (
+     TunedValue, get_device_vmem_limit, get_tuned_block_sizes)
+ from tpu_inference.kernels.quantized_matmul.util import (get_kernel_name,
+                                                          next_multiple,
+                                                          unfold_args)
+
+
+ def quantize_array(
+     x: jax.Array,  # [bs_block_size, in_block_size]
+     x_abs_max: jax.Array,  # [1, bs_block_size]
+     quant_dtype: jnp.dtype,
+ ):
+     is_float = jnp.issubdtype(quant_dtype, jnp.floating)
+     dtype_info = jnp.finfo(quant_dtype) if is_float else jnp.iinfo(quant_dtype)
+     dtype_max = float(dtype_info.max)
+
+     # TODO(kyuyeunk): Investigate performance gain from non xlu transpose.
+     scale = jnp.transpose(x_abs_max / dtype_max)
+     return (x / scale).astype(quant_dtype), scale.astype(jnp.float32)
+
+
+ def get_vmem_limit(
+     n_batch: int,
+     n_out: int,
+     n_in: int,
+     batch_block_size: int,
+     out_block_size: int,
+     in_block_size: int,
+     x_dtype: jnp.dtype,
+     x_q_dtype: jnp.dtype,
+     w_q_dtype: jnp.dtype,
+     scale_dtype: jnp.dtype,
+     out_dtype: jnp.dtype,
+     acc_dtype: jnp.dtype,
+     save_acc: bool,
+     save_x_q: bool,
+     upper_limit_bytes: int,
+ ):
+     """Calculate the VMEM limit for the kernel."""
+
+     # Calculate in/out VMEM size.
+     x_size = batch_block_size * in_block_size * dtypes.bit_width(x_dtype)
+     x_abs_max_size = batch_block_size * dtypes.bit_width(scale_dtype)
+     w_q_size = out_block_size * in_block_size * dtypes.bit_width(w_q_dtype)
+     w_scale_size = out_block_size * dtypes.bit_width(scale_dtype)
+     out_size = batch_block_size * out_block_size * dtypes.bit_width(out_dtype)
+
+     vmem_in_out = x_size + x_abs_max_size + w_q_size + w_scale_size + out_size
+     vmem_in_out *= 2  # Account for compute and vreg spills.
+
+     # Account for double buffering.
+     # Double buffering is used only if there are multiple blocks per in/out.
+     vmem_in_out += x_size if (n_batch > 1 or n_in > 1) else 0
+     vmem_in_out += x_abs_max_size if (n_batch > 1) else 0
+     vmem_in_out += w_q_size if (n_out > 1 or n_in > 1) else 0
+     vmem_in_out += w_scale_size if (n_out > 1) else 0
+     vmem_in_out += out_size if (n_batch > 1 or n_out > 1) else 0
+
+     # Calculate scratch VMEM size.
+     acc_size = batch_block_size * out_block_size * dtypes.bit_width(acc_dtype)
+     x_q_size = batch_block_size * in_block_size * dtypes.bit_width(x_q_dtype)
+     x_scale_size = batch_block_size * dtypes.bit_width(scale_dtype)
+
+     vmem_scratch = acc_size if save_acc else 0
+     vmem_scratch += x_q_size + x_scale_size if save_x_q else 0
+     vmem_scratch *= 2  # Account for compute and vreg spills.
+
+     # Add in/out and scratch VMEM size.
+     vmem_used = vmem_in_out + vmem_scratch
+     vmem_used_bytes = vmem_used // 8  # Convert bits to bytes.
+     # Apply the upper limit. Defaults to 96MB.
+     vmem_limit_bytes = min(vmem_used_bytes, upper_limit_bytes)
+
+     return vmem_limit_bytes
+
+
+ def validate_inputs(
+     x: jax.Array,
+     w_q: jax.Array,
+     w_scale: jax.Array,
+     x_abs_max: jax.Array,
+     x_q_dtype: jnp.dtype,
+     batch_block_size: int,
+     out_block_size: int,
+     in_block_size: int,
+ ):
+     """Verify inputs before invoking the kernel."""
+
+     if x.dtype != x_q_dtype:
+         # If the input is quantized, it must be the same subdtype as w_q.
+         if jnp.issubdtype(x_q_dtype, jnp.integer) != jnp.issubdtype(
+                 w_q.dtype, jnp.integer):
+             raise ValueError(
+                 f'{x_q_dtype=} and {w_q.dtype=} must be the same int or float type.'
+             )
+
+     # Verify input shapes.
+     if x.shape[1] != w_q.shape[1]:
+         raise ValueError(f'{x.shape[1]=} must be equal to {w_q.shape[1]=}')
+     if w_q.shape[0] != w_scale.shape[1]:
+         raise ValueError(
+             f'{w_q.shape[0]=} must be equal to {w_scale.shape[1]=}')
+     if x_abs_max.shape != (1, x.shape[0]):
+         raise ValueError(
+             f'{x_abs_max.shape=} must be equal to (1, {x.shape[0]=})')
+     if x.shape[0] % batch_block_size != 0:
+         raise ValueError(
+             f'{x.shape[0]=} must be a multiple of {batch_block_size=}')
+     if w_q.shape[0] % out_block_size != 0:
+         raise ValueError(
+             f'{w_q.shape[0]=} must be a multiple of {out_block_size=}')
+     if x.shape[1] % in_block_size != 0:
+         raise ValueError(
+             f'{x.shape[1]=} must be a multiple of {in_block_size=}')
+
+
+ def matmul_kernel(
+     x_ref: jax.Array,  # (batch_block_size, in_block_size)
+     w_q_ref: jax.Array,  # (out_block_size, in_block_size)
+     w_scale_ref: jax.Array,  # (1, out_block_size)
+     x_abs_max_ref: jax.Array,  # (1, batch_block_size)
+     out_ref: jax.Array,  # (batch_block_size, out_block_size)
+     acc_scratch: jax.Array,  # (batch_block_size, out_block_size)
+     x_q_scratch: jax.Array,  # (batch_block_size, in_block_size)
+     x_scale_scratch: jax.Array,  # (batch_block_size, 1)
+     *,
+     x_q_dtype: jnp.dtype,
+     save_acc: bool,
+     save_x_q: bool,
+ ):
+     out_idx, in_idx = pl.program_id(1), pl.program_id(2)
+     n_in = pl.num_programs(2)
+     x_ref_dtype = x_ref.dtype
+
+     quantize_activation = x_q_dtype != x_ref_dtype
+
+     # Initialize conditional logic.
+     if save_x_q:
+         assert quantize_activation
+         assert x_q_scratch is not None
+         assert x_scale_scratch is not None
+         quant = out_idx == 0
+     else:
+         assert x_q_scratch is None
+         assert x_scale_scratch is None
+         quant = quantize_activation
+
+     if save_acc:
+         assert acc_scratch is not None
+         is_first_step = in_idx == 0
+         is_last_step = in_idx == (n_in - 1)
+     else:
+         assert acc_scratch is None
+         is_first_step = True
+         is_last_step = True
+
+     acc_dtype = jnp.float32
+     if quantize_activation and jnp.issubdtype(w_q_ref.dtype, jnp.integer):
+         acc_dtype = jnp.int32
+
+     # Start of actual computation logic.
+     def matmul_body(quant: bool, is_first_step: bool, is_last_step: bool):
+         if quantize_activation:
+             if quant:
+                 x_q_tmp, x_scale_tmp = quantize_array(
+                     x_ref[...],
+                     x_abs_max_ref[...],
+                     x_q_dtype,
+                 )
+
+                 if save_x_q:
+                     x_q_scratch[...] = x_q_tmp
+                     x_scale_scratch[...] = x_scale_tmp
+             else:
+                 assert save_x_q
+                 x_q_tmp = x_q_scratch[...]
+                 if is_last_step:
+                     x_scale_tmp = x_scale_scratch[...]
+
+             acc = jax.lax.dot_general(
+                 x_q_tmp,
+                 w_q_ref[...],
+                 (((1, ), (1, )), ((), ())),
+                 preferred_element_type=acc_dtype,
+             )
+         else:
+             acc = jax.lax.dot_general(
+                 x_ref[...],
+                 w_q_ref[...],
+                 (((1, ), (1, )), ((), ())),
+                 preferred_element_type=acc_dtype,
+             )
+
+         if not is_first_step:
+             acc += acc_scratch[...]
+
+         if is_last_step:
+             acc *= w_scale_ref[...]
+             if quantize_activation:
+                 # TODO(kyuyeunk): Investigate caching broadcast.
+                 acc *= x_scale_tmp
+             out_ref[...] = acc.astype(x_ref_dtype)
+         else:
+             assert save_acc
+             acc_scratch[...] = acc
+
+     unfold_args((quant, is_first_step, is_last_step), (), matmul_body)
+
+
+ @functools.partial(
+     jax.jit,
+     static_argnames=[
+         'x_q_dtype',
+         'tuned_value',
+     ],
+ )
+ def quantized_matmul_kernel(
+     x: jax.Array,  # [bs, n_in]
+     w_q: jax.Array,  # [n_out, n_in]
+     w_scale: jax.Array,  # [n_out]
+     w_zp: jax.Array | None = None,  # [n_out]
+     block_size: int | None = None,
+     x_q_dtype: jnp.dtype | None = None,
+     *,
+     tuned_value: TunedValue | None = None,
+ ) -> jax.Array:
+     """Quantized matmul kernel.
+
+     Args:
+       x: Unquantized input array.
+       w_q: Quantized weight array. [n_output_features, n_input_features]
+       w_scale: Weight quantization scale. [n_output_features]
+       w_zp: Weight zero point for asymmetric quantization.
+       block_size: Block size for subchannel quantization.
+       x_q_dtype: Quantization dtype of the input. If None, or if the value is
+         the same as x.dtype, no quantization is applied.
+       tuned_value: Tuned kernel parameters for optimal performance.
+
+     Returns:
+       Quantized matmul result.
+     """
+
+     if w_zp is not None:
+         raise NotImplementedError('zero_point is not supported.')
+     if block_size is not None:
+         raise NotImplementedError('block_size is not supported.')
+
+     if x_q_dtype is None:
+         x_q_dtype = x.dtype
+     quantize_activation = x_q_dtype != x.dtype
+
+     # The Pallas kernel only has access to a single block of the input.
+     # Therefore, for per-token quantization, the abs max has to be computed
+     # outside of the kernel.
+     x_abs_max = jnp.max(jnp.abs(x), axis=-1, keepdims=False)  # [bs]
+     # Pallas requires the minormost dim to be a multiple of the lane size 128.
+     # Therefore, instead of using [bs, 1], we reshape this into [1, bs].
+     x_abs_max = jnp.expand_dims(x_abs_max, axis=0)  # [1, bs]
+     assert x_abs_max.shape == (1, x.shape[0])
+
+     orig_n_batch, orig_n_in = x.shape
+     orig_n_out, _ = w_q.shape
+
+     if tuned_value is None:
+         tuned_value = get_tuned_block_sizes(
+             n_batch=orig_n_batch,
+             n_out=orig_n_out,
+             n_in=orig_n_in,
+             x_q_dtype=jnp.dtype(x_q_dtype).name,
+             w_q_dtype=jnp.dtype(w_q.dtype).name,
+         )
+     batch_block_size = tuned_value.batch_block_size
+     out_block_size = tuned_value.out_block_size
+     in_block_size = tuned_value.in_block_size
+
+     # Pad the inputs to a multiple of the block size.
+     padded_n_batch = next_multiple(orig_n_batch, batch_block_size)
+     if orig_n_batch < padded_n_batch:
+         x = jnp.pad(x, ((0, padded_n_batch - orig_n_batch), (0, 0)))
+         x_abs_max = jnp.pad(x_abs_max,
+                             ((0, 0), (0, padded_n_batch - orig_n_batch)))
+     padded_n_out = next_multiple(orig_n_out, out_block_size)
+     if orig_n_out < padded_n_out:
+         w_q = jnp.pad(w_q, ((0, padded_n_out - orig_n_out), (0, 0)))
+         w_scale = jnp.pad(w_scale, (0, padded_n_out - orig_n_out))
+     padded_n_in = next_multiple(orig_n_in, in_block_size)
+     if orig_n_in < padded_n_in:
+         x = jnp.pad(x, ((0, 0), (0, padded_n_in - orig_n_in)))
+         w_q = jnp.pad(w_q, ((0, 0), (0, padded_n_in - orig_n_in)))
+
+     if w_scale.dtype != jnp.float32:
+         w_scale = w_scale.astype(jnp.float32)
+     w_scale = jnp.expand_dims(w_scale, axis=0)  # [1, n_output_features]
+
+     n_batch = padded_n_batch // batch_block_size
+     n_out = padded_n_out // out_block_size
+     n_in = padded_n_in // in_block_size
+
+     save_acc = n_in > 1
+     # Remove redundant input quantization logic by caching the quantized
+     # input. For best performance, only enable this behavior when a single
+     # input block is used per batch.
+     save_x_q = quantize_activation and n_in == 1 and n_out > 1
+
+     acc_dtype = jnp.float32
+     if quantize_activation and jnp.issubdtype(w_q.dtype, jnp.integer):
+         acc_dtype = jnp.int32
+
+     vmem_limit_bytes = get_vmem_limit(
+         n_batch=n_batch,
+         n_out=n_out,
+         n_in=n_in,
+         batch_block_size=batch_block_size,
+         out_block_size=out_block_size,
+         in_block_size=in_block_size,
+         x_dtype=x.dtype,
+         x_q_dtype=x_q_dtype,
+         w_q_dtype=w_q.dtype,
+         scale_dtype=jnp.float32,
+         out_dtype=x.dtype,
+         acc_dtype=acc_dtype,
+         save_acc=save_acc,
+         save_x_q=save_x_q,
+         upper_limit_bytes=get_device_vmem_limit(),
+     )
+
+     kernel = pl.pallas_call(
+         functools.partial(
+             matmul_kernel,
+             x_q_dtype=x_q_dtype,
+             save_acc=save_acc,
+             save_x_q=save_x_q,
+         ),
+         grid_spec=pltpu.PrefetchScalarGridSpec(
+             num_scalar_prefetch=0,
+             in_specs=[
+                 pl.BlockSpec((batch_block_size, in_block_size),
+                              lambda b, o, i: (b, i)),  # x
+                 pl.BlockSpec((out_block_size, in_block_size),
+                              lambda b, o, i: (o, i)),  # w_q
+                 pl.BlockSpec((1, out_block_size),
+                              lambda b, o, i: (0, o)),  # w_scale
+                 pl.BlockSpec((1, batch_block_size),
+                              lambda b, o, i: (0, b)),  # x_abs_max
+             ],
+             out_specs=pl.BlockSpec((batch_block_size, out_block_size),
+                                    lambda b, o, i: (b, o)),
+             scratch_shapes=[
+                 pltpu.VMEM((batch_block_size, out_block_size), acc_dtype)
+                 if save_acc else None,  # acc_scratch
+                 pltpu.VMEM((batch_block_size, in_block_size), x_q_dtype)
+                 if save_x_q else None,  # x_q_scratch
+                 pltpu.VMEM((batch_block_size, 1), jnp.float32)
+                 if save_x_q else None,  # x_scale_scratch
+             ],
+             grid=(n_batch, n_out, n_in),
+         ),
+         out_shape=jax.ShapeDtypeStruct((padded_n_batch, padded_n_out),
+                                        x.dtype),
+         compiler_params=pltpu.CompilerParams(
+             dimension_semantics=('parallel', 'arbitrary', 'arbitrary'),
+             vmem_limit_bytes=vmem_limit_bytes,
+         ),
+     )
+
+     validate_inputs(
+         x=x,
+         w_q=w_q,
+         w_scale=w_scale,
+         x_abs_max=x_abs_max,
+         x_q_dtype=x_q_dtype,
+         batch_block_size=batch_block_size,
+         out_block_size=out_block_size,
+         in_block_size=in_block_size,
+     )
+
+     # The named_scope is used for autotune.
+     kernel_name = get_kernel_name(tuned_value)
+     with jax.named_scope(kernel_name):
+         out = kernel(x, w_q, w_scale, x_abs_max)
+
+     return out[:orig_n_batch, :orig_n_out]
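
For orientation, the sketch below shows how this kernel might be called end to end. It is not part of the package diff: quantize_weights_int8 is a hypothetical helper that applies the same symmetric rule quantize_array uses internally (scale = abs_max / dtype_max), and the example assumes a TPU runtime where the tuned block-size table covers these shapes.

import jax
import jax.numpy as jnp

from tpu_inference.kernels.quantized_matmul.kernel import quantized_matmul_kernel


def quantize_weights_int8(w: jax.Array):
    # Hypothetical helper (not in the package): symmetric per-output-channel
    # int8 quantization, mirroring quantize_array's scale = abs_max / dtype_max.
    abs_max = jnp.max(jnp.abs(w), axis=-1, keepdims=True)  # [n_out, 1]
    scale = abs_max / 127.0  # 127 == jnp.iinfo(jnp.int8).max
    return (w / scale).astype(jnp.int8), scale.squeeze(-1).astype(jnp.float32)


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (256, 1024), dtype=jnp.bfloat16)   # [bs, n_in]
w = jax.random.normal(key_w, (2048, 1024), dtype=jnp.bfloat16)  # [n_out, n_in]
w_q, w_scale = quantize_weights_int8(w)

# Passing x_q_dtype=jnp.int8 asks the kernel to quantize activations per token
# on the fly; the result approximates x @ w.T and comes back in x.dtype.
out = quantized_matmul_kernel(x, w_q, w_scale, x_q_dtype=jnp.int8)
assert out.shape == (256, 2048) and out.dtype == jnp.bfloat16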