tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of tpu-inference might be problematic.

Files changed (248)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +14 -0
  31. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  32. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
  35. tests/layers/__init__.py +13 -0
  36. tests/layers/common/__init__.py +13 -0
  37. tests/layers/common/test_attention_interface.py +156 -0
  38. tests/layers/common/test_quantization.py +149 -0
  39. tests/layers/jax/__init__.py +13 -0
  40. tests/layers/jax/attention/__init__.py +13 -0
  41. tests/layers/jax/attention/test_common_attention.py +103 -0
  42. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  43. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  44. tests/layers/jax/moe/__init__.py +13 -0
  45. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  46. tests/layers/jax/sample/__init__.py +13 -0
  47. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  48. tests/layers/jax/sample/test_sampling.py +115 -0
  49. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  50. tests/layers/jax/test_layers.py +155 -0
  51. tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
  52. tests/layers/jax/test_rope.py +93 -0
  53. tests/layers/jax/test_sharding.py +159 -0
  54. tests/layers/jax/test_transformer_block.py +152 -0
  55. tests/layers/vllm/__init__.py +13 -0
  56. tests/layers/vllm/test_attention.py +363 -0
  57. tests/layers/vllm/test_awq.py +406 -0
  58. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  59. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  61. tests/layers/vllm/test_fp8.py +17 -0
  62. tests/layers/vllm/test_mxfp4.py +320 -0
  63. tests/layers/vllm/test_unquantized.py +662 -0
  64. tests/layers/vllm/utils.py +87 -0
  65. tests/lora/__init__.py +13 -0
  66. tests/lora/conftest.py +14 -0
  67. tests/lora/test_bgmv.py +14 -0
  68. tests/lora/test_layers.py +25 -8
  69. tests/lora/test_lora.py +15 -1
  70. tests/lora/test_lora_perf.py +14 -0
  71. tests/models/__init__.py +13 -0
  72. tests/models/common/__init__.py +13 -0
  73. tests/models/common/test_model_loader.py +455 -0
  74. tests/models/jax/__init__.py +13 -0
  75. tests/models/jax/test_deepseek_v3.py +401 -0
  76. tests/models/jax/test_llama3.py +184 -0
  77. tests/models/jax/test_llama4.py +298 -0
  78. tests/models/jax/test_llama_eagle3.py +197 -0
  79. tests/models/jax/test_llama_guard_4.py +242 -0
  80. tests/models/jax/test_qwen2.py +172 -0
  81. tests/models/jax/test_qwen2_5_vl.py +605 -0
  82. tests/models/jax/test_qwen3.py +169 -0
  83. tests/models/jax/test_weight_loading.py +180 -0
  84. tests/models/jax/utils/__init__.py +13 -0
  85. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  86. tests/platforms/__init__.py +13 -0
  87. tests/platforms/test_tpu_platform.py +54 -0
  88. tests/runner/__init__.py +13 -0
  89. tests/runner/test_block_table.py +395 -0
  90. tests/runner/test_input_batch.py +226 -0
  91. tests/runner/test_kv_cache.py +220 -0
  92. tests/runner/test_kv_cache_manager.py +498 -0
  93. tests/runner/test_multimodal_manager.py +429 -0
  94. tests/runner/test_persistent_batch_manager.py +84 -0
  95. tests/runner/test_speculative_decoding_manager.py +368 -0
  96. tests/runner/test_structured_decoding_manager.py +220 -0
  97. tests/runner/test_tpu_runner.py +261 -0
  98. tests/runner/test_tpu_runner_dp.py +1099 -0
  99. tests/runner/test_tpu_runner_mesh.py +200 -0
  100. tests/runner/test_utils.py +411 -0
  101. tests/spec_decode/__init__.py +13 -0
  102. tests/spec_decode/test_eagle3.py +311 -0
  103. tests/test_base.py +14 -0
  104. tests/test_tpu_info.py +14 -0
  105. tests/test_utils.py +1 -43
  106. tests/worker/__init__.py +13 -0
  107. tests/worker/tpu_worker_test.py +414 -0
  108. tpu_inference/__init__.py +14 -0
  109. tpu_inference/core/__init__.py +13 -0
  110. tpu_inference/core/sched/__init__.py +13 -0
  111. tpu_inference/core/sched/dp_scheduler.py +372 -56
  112. tpu_inference/distributed/__init__.py +13 -0
  113. tpu_inference/distributed/jax_parallel_state.py +14 -0
  114. tpu_inference/distributed/tpu_connector.py +14 -9
  115. tpu_inference/distributed/utils.py +56 -4
  116. tpu_inference/executors/__init__.py +13 -0
  117. tpu_inference/executors/ray_distributed_executor.py +20 -3
  118. tpu_inference/experimental/__init__.py +13 -0
  119. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  120. tpu_inference/kernels/__init__.py +13 -0
  121. tpu_inference/kernels/collectives/__init__.py +13 -0
  122. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  123. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  124. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  125. tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
  126. tpu_inference/kernels/megablox/__init__.py +13 -0
  127. tpu_inference/kernels/megablox/common.py +54 -0
  128. tpu_inference/kernels/megablox/gmm.py +646 -0
  129. tpu_inference/kernels/mla/__init__.py +13 -0
  130. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  131. tpu_inference/kernels/mla/v1/kernel.py +20 -26
  132. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  133. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  134. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  135. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  136. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
  137. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
  138. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  139. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
  140. tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
  141. tpu_inference/layers/__init__.py +13 -0
  142. tpu_inference/layers/common/__init__.py +13 -0
  143. tpu_inference/layers/common/attention_interface.py +26 -19
  144. tpu_inference/layers/common/attention_metadata.py +14 -0
  145. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  146. tpu_inference/layers/common/quant_methods.py +15 -0
  147. tpu_inference/layers/common/quantization.py +282 -0
  148. tpu_inference/layers/common/sharding.py +22 -3
  149. tpu_inference/layers/common/utils.py +94 -0
  150. tpu_inference/layers/jax/__init__.py +13 -0
  151. tpu_inference/layers/jax/attention/__init__.py +13 -0
  152. tpu_inference/layers/jax/attention/attention.py +19 -6
  153. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
  154. tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
  155. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  156. tpu_inference/layers/jax/base.py +14 -0
  157. tpu_inference/layers/jax/constants.py +13 -0
  158. tpu_inference/layers/jax/layers.py +14 -0
  159. tpu_inference/layers/jax/misc.py +14 -0
  160. tpu_inference/layers/jax/moe/__init__.py +13 -0
  161. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  162. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  163. tpu_inference/layers/jax/moe/moe.py +43 -3
  164. tpu_inference/layers/jax/pp_utils.py +53 -0
  165. tpu_inference/layers/jax/rope.py +14 -0
  166. tpu_inference/layers/jax/rope_interface.py +14 -0
  167. tpu_inference/layers/jax/sample/__init__.py +13 -0
  168. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  169. tpu_inference/layers/jax/sample/sampling.py +15 -1
  170. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  171. tpu_inference/layers/jax/transformer_block.py +14 -0
  172. tpu_inference/layers/vllm/__init__.py +13 -0
  173. tpu_inference/layers/vllm/attention.py +4 -4
  174. tpu_inference/layers/vllm/fused_moe.py +100 -455
  175. tpu_inference/layers/vllm/linear.py +64 -0
  176. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  177. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  178. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  179. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  180. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  181. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  182. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  183. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
  184. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  188. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
  189. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  190. tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
  191. tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
  192. tpu_inference/lora/__init__.py +13 -0
  193. tpu_inference/lora/torch_lora_ops.py +8 -13
  194. tpu_inference/models/__init__.py +13 -0
  195. tpu_inference/models/common/__init__.py +13 -0
  196. tpu_inference/models/common/model_loader.py +37 -16
  197. tpu_inference/models/jax/__init__.py +13 -0
  198. tpu_inference/models/jax/deepseek_v3.py +113 -124
  199. tpu_inference/models/jax/gpt_oss.py +23 -7
  200. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  201. tpu_inference/models/jax/llama3.py +99 -36
  202. tpu_inference/models/jax/llama4.py +14 -0
  203. tpu_inference/models/jax/llama_eagle3.py +14 -0
  204. tpu_inference/models/jax/llama_guard_4.py +15 -1
  205. tpu_inference/models/jax/qwen2.py +17 -2
  206. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  207. tpu_inference/models/jax/qwen3.py +17 -2
  208. tpu_inference/models/jax/utils/__init__.py +13 -0
  209. tpu_inference/models/jax/utils/file_utils.py +14 -0
  210. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  211. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
  213. tpu_inference/models/jax/utils/weight_utils.py +32 -1
  214. tpu_inference/models/vllm/__init__.py +13 -0
  215. tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
  216. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  217. tpu_inference/platforms/__init__.py +14 -0
  218. tpu_inference/platforms/tpu_platform.py +27 -29
  219. tpu_inference/runner/__init__.py +13 -0
  220. tpu_inference/runner/compilation_manager.py +69 -35
  221. tpu_inference/runner/kv_cache.py +14 -0
  222. tpu_inference/runner/kv_cache_manager.py +15 -2
  223. tpu_inference/runner/lora_utils.py +16 -1
  224. tpu_inference/runner/multimodal_manager.py +16 -2
  225. tpu_inference/runner/persistent_batch_manager.py +14 -0
  226. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  227. tpu_inference/runner/structured_decoding_manager.py +14 -0
  228. tpu_inference/runner/tpu_runner.py +30 -10
  229. tpu_inference/spec_decode/__init__.py +13 -0
  230. tpu_inference/spec_decode/jax/__init__.py +13 -0
  231. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  232. tpu_inference/tpu_info.py +14 -0
  233. tpu_inference/utils.py +31 -30
  234. tpu_inference/worker/__init__.py +13 -0
  235. tpu_inference/worker/tpu_worker.py +23 -7
  236. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
  237. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  238. tpu_inference/layers/vllm/linear_common.py +0 -208
  239. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  240. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  241. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  242. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  245. tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
  246. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  247. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  248. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,54 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Common utilities for GMM kernels."""
+
+ import re
+
+ import jax
+ import jax.numpy as jnp
+
+
+ def is_tpu() -> bool:
+     return "TPU" in jax.devices()[0].device_kind
+
+
+ def tpu_kind() -> str:
+     """Query identification string for the currently attached TPU."""
+     return jax.devices()[0].device_kind
+
+
+ # Most TPU devices follow the pattern "TPU v{version}{variant}", e.g. "TPU v5p"
+ # TPU v7 has a different pattern (i.e. "TPU7x")
+ _TPU_KIND_PATTERN = re.compile(r"TPU( v)?(\d+)")
+
+
+ def tpu_generation() -> int:
+     """Generation number of the currently attached TPU."""
+     if version := _TPU_KIND_PATTERN.match(tpu_kind()):
+         return int(version[2])
+     raise NotImplementedError("only TPU devices are supported")
+
+
+ def assert_is_supported_dtype(dtype: jnp.dtype) -> None:
+     if dtype not in [
+             jnp.bfloat16,
+             jnp.float32,
+             jnp.float8_e4m3fn,
+             jnp.float8_e5m2,
+             jnp.int8,
+             jnp.int4,
+             jnp.float4_e2m1fn,
+             jnp.uint4,
+     ]:
+         raise ValueError(f"No support for {dtype=}.")
@@ -0,0 +1,646 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Grouped matrix multiplication kernels for TPU written in Pallas."""
+
+ import functools
+ from collections.abc import Callable
+ from typing import Any, Optional
+
+ import jax
+ import jax.numpy as jnp
+ from jax import lax
+ from jax.experimental import pallas as pl
+ from jax.experimental.pallas import tpu as pltpu
+
+ from tpu_inference.kernels.megablox import common
+
+ partial = functools.partial
+
+
+ def _validate_args(
+     *,
+     lhs: jnp.ndarray,
+     rhs: jnp.ndarray,
+     group_sizes: jnp.ndarray,
+     rhs_scale: jnp.ndarray | None = None,
+     rhs_bias: jnp.ndarray | None = None,
+ ):
+     """Validates the arguments for the gmm function."""
+     # Validate 'lhs'.
+     if lhs.ndim != 2:
+         raise ValueError(f"Expected 2-tensor for 'lhs' but got {lhs.ndim=}.")
+     common.assert_is_supported_dtype(lhs.dtype)
+
+     # Validate 'rhs'.
+     if rhs.ndim != 3:
+         raise ValueError(f"Expected 3-tensor for 'rhs' but got {rhs.ndim=}.")
+     common.assert_is_supported_dtype(rhs.dtype)
+
+     if lhs.shape[1] != rhs.shape[2]:
+         raise ValueError(
+             "Expected 'lhs' and 'rhs' to have the same number of input features."
+             f" But instead got {lhs.shape[1]=} and {rhs.shape[2]=}")
+
+     # Validate 'group_sizes'.
+     if group_sizes.dtype != jnp.int32:
+         raise ValueError(
+             f"Expected 32-bit integer 'group_sizes' but got {group_sizes.dtype=}."
+         )
+
+     num_groups, out_size, in_size = rhs.shape
+
+     if rhs_scale is not None:
+         # Validate 'rhs_scale'.
+         if rhs_scale.ndim != 4:
+             raise ValueError(
+                 f"Expected 4-tensor for 'rhs_scale' but got {rhs_scale.ndim=}."
+             )
+         expected_rhs_scale_shape = (num_groups, rhs_scale.shape[1], 1,
+                                     out_size)
+         if rhs_scale.shape != expected_rhs_scale_shape:
+             raise ValueError(
+                 "Expected 'rhs_scale' to have the shape of"
+                 f" {expected_rhs_scale_shape} but got {rhs_scale.shape=}.")
+
+     if rhs_bias is not None:
+         # Validate 'rhs_bias'.
+         if rhs_bias.ndim != 3:
+             raise ValueError(
+                 f"Expected 3-tensor for 'rhs_bias' but got {rhs_bias.ndim=}.")
+         expected_rhs_bias_shape = (num_groups, 1, out_size)
+         if rhs_bias.shape != expected_rhs_bias_shape:
+             raise ValueError(
+                 "Expected 'rhs_bias' to have the shape of"
+                 f" {expected_rhs_bias_shape} but got {rhs_bias.shape=}.")
+
+
+ def _calculate_num_tiles(x: int, tx: int) -> int:
+     tiles, rem = divmod(x, tx)
+     if rem:
+         raise ValueError(
+             f"{x} must be divisible by x-dimension tile size ({tx}).")
+     return tiles
+
+
+ def _calculate_irregular_num_tiles(x: int, tx: int) -> tuple[int, int]:
+     tiles, rem = divmod(x, tx)
+     if rem:
+         tiles += 1
+     return tiles, rem
+
+
+ GroupMetadata = Any  # TODO(enriqueps): Clean this up and use a namedtuple
+
+
+ def make_group_metadata(
+     *,
+     group_sizes: jnp.ndarray,
+     m: int,
+     tm: int,
+     start_group: jnp.ndarray,
+     num_nonzero_groups: int,
+     visit_empty_groups: bool = True,
+ ) -> GroupMetadata:
+     """Create the metadata needed for grouped matmul computation.
+
+     Args:
+       group_sizes: A 1d, jnp.ndarray with shape [num_groups] and jnp.int32 dtype.
+       m: The number of rows in lhs.
+       tm: The m-dimension tile size being used.
+       start_group: The group in group sizes to start computing from. This is
+         particularly useful for when rhs num_groups is sharded.
+       num_nonzero_groups: Number of groups in group sizes to compute on. Useful in
+         combination with group_offset.
+       visit_empty_groups: If True, do not squeeze tiles for empty groups out of
+         the metadata. This is necessary for tgmm, where we at least need to zero
+         the output for each group.
+
+     Returns:
+       tuple of:
+         group_offsets: A 1d, jnp.ndarray with shape [num_groups+1] and jnp.int32
+           dtype. group_offsets[i] indicates the row at which group [i] starts in
+           the lhs matrix and group_offsets[num_groups] = m.
+         group_ids: A 1d, jnp.ndarray with shape [m_tiles + num_groups] and
+           jnp.int32 dtype. group_ids[i] indicates which group grid index 'i' will
+           work on.
+         m_tile_ids: A 1d, jnp.ndarray with shape [m_tiles + num_groups] and
+           jnp.int32. m_tile_ids[i] indicates which m-dimension tile grid index 'i'
+           will work on.
+         num_tiles: The number of m-dimension tiles to execute.
+     """
+     num_groups = group_sizes.shape[0]
+     end_group = start_group + num_nonzero_groups - 1
+
+     # Calculate the offset of each group, starting at zero. This metadata is
+     # similar to row offsets in a CSR matrix. The following properties hold:
+     #
+     # group_offsets.shape = [num_groups + 1]
+     # group_offsets[0] = 0
+     # group_offsets[num_groups] = m
+     #
+     # The row at which group 'i' starts is group_offsets[i].
+     group_ends = jnp.cumsum(group_sizes)
+     group_offsets = jnp.concatenate(
+         [jnp.zeros(1, dtype=jnp.int32), group_ends])
+
+     # Assign a group id to each grid index.
+     #
+     # If a group starts somewhere other than the start of a tile or ends somewhere
+     # other than the end of a tile we need to compute that full tile. Calculate
+     # the number of tiles for each group by rounding their end up to the nearest
+     # 'tm' and their start down to the nearest 'tm'.
+
+     # (1) Round the group_ends up to the nearest multiple of 'tm'.
+     #
+     # NOTE: This does not change group_offsets[num_groups], which is m
+     # (because we enforce m is divisible by tm).
+     rounded_group_ends = ((group_ends + tm - 1) // tm * tm).astype(jnp.int32)
+
+     # (2) Round the group_starts down to the nearest multiple of 'tm'.
+     group_starts = jnp.concatenate(
+         [jnp.zeros(1, dtype=jnp.int32), group_ends[:-1]])
+     rounded_group_starts = group_starts // tm * tm
+
+     # (3) Calculate the number of rows in each group.
+     #
+     # NOTE: Handle zero-sized groups as a special case. If the start for a
+     # zero-sized group is not divisible by 'tm' its start will be rounded down and
+     # its end will be rounded up such that its size will become 1 tile here.
+     rounded_group_sizes = rounded_group_ends - rounded_group_starts
+     rounded_group_sizes = jnp.where(group_sizes == 0, 0, rounded_group_sizes)
+
+     # (4) Convert the group sizes from units of rows to units of 'tm' sized tiles.
+     #
+     # An m-dimension tile is 'owned' by group 'i' if the first row of the tile
+     # belongs to group 'i'. In addition to owned tiles, each group can have 0 or 1
+     # initial partial tiles if its first row does not occur in the first row of a
+     # tile. The '0-th' group never has a partial tile because it always starts at
+     # the 0-th row.
+     #
+     # If no group has a partial tile, the total number of tiles is equal to
+     # 'm // tm'. If every group has a partial tile except the 0-th group, the total
+     # number of tiles is equal to 'm // tm + num_groups - 1'. Thus we know that
+     #
+     # tiles_m <= group_tiles.sum() <= tiles_m + num_groups - 1
+     #
+     # Where tiles_m = m // tm.
+     #
+     # NOTE: All group sizes are divisible by 'tm' because of the rounding in steps
+     # (1) and (2) so this division is exact.
+     group_tiles = rounded_group_sizes // tm
+
+     if visit_empty_groups:
+         # Insert one tile for empty groups.
+         group_tiles = jnp.where(group_sizes == 0, 1, group_tiles)
+
+     # Create the group ids for each grid index based on the tile counts for each
+     # group.
+     #
+     # NOTE: This repeat(...) will pad group_ids with the final group id if
+     # group_tiles.sum() < tiles_m + num_groups - 1. The kernel grid will be sized
+     # such that we only execute the necessary number of tiles.
+     tiles_m = _calculate_num_tiles(m, tm)
+     group_ids = jnp.repeat(
+         jnp.arange(num_groups, dtype=jnp.int32),
+         group_tiles,
+         total_repeat_length=tiles_m + num_groups - 1,
+     )
+
+     # Assign an m-dimension tile id to each grid index.
+     #
+     # NOTE: Output tiles can only be re-visited consecutively. The following
+     # procedure guarantees that m-dimension tile indices respect this.
+
+     # (1) Calculate how many times each m-dimension tile will be visited.
+     #
+     # Each tile is guaranteed to be visited once by the group that owns the tile.
+     # The remaining possible visits occur when a group starts inside of a tile at
+     # a position other than the first row. We can calculate which m-dimension tile
+     # each group starts in by floor-dividing its offset with `tm` and then count
+     # tile visits with a histogram.
+     #
+     # To avoid double counting tile visits from the group that owns the tile,
+     # filter these out by assigning their tile id to `tile_m` (one beyond the max)
+     # such that they're ignored by the subsequent histogram. Also filter out any
+     # group which is empty.
+     #
+     # TODO(tgale): Invert the 'partial_tile_mask' predicates to be more clear.
+     partial_tile_mask = jnp.logical_or((group_offsets[:-1] % tm) == 0,
+                                        group_sizes == 0)
+
+     # Explicitly enable tiles for zero sized groups, if specified. This covers
+     # zero sized groups that start on a tile-aligned row and those that do not.
+     if visit_empty_groups:
+         partial_tile_mask = jnp.where(group_sizes == 0, 0, partial_tile_mask)
+
+     partial_tile_ids = jnp.where(partial_tile_mask, tiles_m,
+                                  group_offsets[:-1] // tm)
+
+     tile_visits = (jnp.histogram(
+         partial_tile_ids, bins=tiles_m, range=(0, tiles_m - 1))[0] + 1)
+
+     # Create the m-dimension tile ids for each grid index based on the visit
+     # counts for each tile.
+     m_tile_ids = jnp.repeat(
+         jnp.arange(tiles_m, dtype=jnp.int32),
+         tile_visits.astype(jnp.int32),
+         total_repeat_length=tiles_m + num_groups - 1,
+     )
+
+     # Account for sharding.
+     #
+     # Find the start of the groups owned by our shard and shift the group_ids and
+     # m_tile_ids s.t. the metadata for our tiles are at the front of the arrays.
+     #
+     # TODO(tgale): Move this offset into the kernel to avoid these rolls.
+     first_tile_in_shard = (group_ids < start_group).sum()
+     group_ids = jnp.roll(group_ids, shift=-first_tile_in_shard, axis=0)
+     m_tile_ids = jnp.roll(m_tile_ids, shift=-first_tile_in_shard, axis=0)
+
+     # Calculate the number of tiles we need to compute for our shard.
+     #
+     # Remove tile visits that belong to a group not in our shard.
+     iota = jnp.arange(num_groups, dtype=jnp.int32)
+     active_group_mask = jnp.logical_and(iota <= end_group, iota >= start_group)
+     group_tiles = jnp.where(active_group_mask, group_tiles, 0)
+     num_tiles = group_tiles.sum()
+     return (group_offsets, group_ids, m_tile_ids), num_tiles
+
+
+ def _get_store_mask(
+     *,
+     grid_id: jnp.ndarray,
+     group_metadata: GroupMetadata,
+     tm: int,
+     tn: int,
+ ) -> jnp.ndarray:
+     """Mask for rows that belong to the current group in the current tile."""
+     group_offsets, group_ids, m_tile_ids = group_metadata[:3]
+     group_id = group_ids[grid_id]
+     group_start = group_offsets[group_id]
+     group_end = group_offsets[group_id + 1]
+     m_id = m_tile_ids[grid_id] * tm
+     iota = jax.lax.broadcasted_iota(jnp.int32, (tm, tn), 0) + m_id
+     return jnp.logical_and(iota >= group_start, iota < group_end)
+
+
+ def _zero_uninitialized_memory(
+     out: jnp.ndarray,
+     *,
+     start_group: jnp.ndarray,
+     num_nonzero_groups: int,
+     group_metadata: GroupMetadata,
+ ) -> jnp.ndarray:
+     """Zero out uninitialized memory from output."""
+     group_offsets = group_metadata[0]
+     group_start = group_offsets[start_group]
+     group_end = group_offsets[start_group + num_nonzero_groups]
+     valid_mask = jax.lax.broadcasted_iota(jnp.int32, (out.shape[0], ), 0)
+     valid_mask = (valid_mask >= group_start) & (valid_mask < group_end)
+     return jnp.where(valid_mask[:, None], out, 0)
+
+
+ LutFn = Callable[[int, int, int], Optional[tuple[int, int, int]]]
+
+
+ @functools.partial(
+     jax.jit,
+     static_argnames=[
+         "preferred_element_type",
+         "tiling",
+         "transpose_rhs",
+         "interpret",
+     ],
+ )
+ def gmm(
+     lhs: jnp.ndarray,
+     rhs: jnp.ndarray,
+     group_sizes: jnp.ndarray,
+     preferred_element_type: jnp.dtype = jnp.float32,
+     rhs_scale: jnp.ndarray | None = None,
+     rhs_bias: jnp.ndarray | None = None,
+     tiling: tuple[int, int, int] | LutFn | None = (128, 128, 128),
+     group_offset: jnp.ndarray | None = None,
+     existing_out: jnp.ndarray | None = None,
+     transpose_rhs: bool = False,
+     interpret: bool = False,
+ ) -> jnp.ndarray:
+     """Compute lhs[sizes[i-1]:sizes[i], :] @ rhs for each group 'i'.
+
+     Args:
+       lhs: A 2d, jnp.ndarray with shape [m, k].
+       rhs: A 3d, jnp.ndarray with shape [num_groups, n, k].
+       group_sizes: A 1d, jnp.ndarray with shape [num_groups] and jnp.int32 dtype.
+       preferred_element_type: jnp.dtype, the element type for the output matrix.
+       rhs_scale: A 4d, jnp.ndarray with shape [num_groups, num_blocks, 1, n].
+       rhs_bias: A 3d, jnp.ndarray with shape [num_groups, 1, n].
+       tiling: 3-tuple of ints. The m, k and n-dimension tile sizes.
+       group_offset: The group in group sizes to start computing from. This is
+         particularly useful for when rhs num_groups is sharded.
+       existing_out: Existing output to write to.
+       transpose_rhs: True if the rhs needs to be transposed.
+       interpret: Whether or not to run the kernel in interpret mode, helpful for
+         testing and debugging.
+
+     Returns:
+       A 2d, jnp.ndarray with shape [m, n].
+     """
+
+     # TODO(kyuyeunk): Instead of transpose_rhs==True, modify the logic to only
+     # support transpose_rhs==False, as it simplifies the kernel logic.
+     assert transpose_rhs
+
+     if existing_out is not None:
+         assert isinstance(existing_out, jax.Array)
+         expected_dtype = existing_out.dtype
+         if expected_dtype != preferred_element_type:
+             raise ValueError(
+                 "Existing output dtype must match preferred_element_type.")
+     if group_offset is None:
+         group_offset = jnp.array([0], dtype=jnp.int32)
+     else:
+         if group_offset.shape:
+             raise ValueError(
+                 f"group_offset must be a ()-shaped array. Got: {group_offset.shape}."
+             )
+         group_offset = group_offset[None]
+     num_current_groups = rhs.shape[0]
+     num_total_groups = group_sizes.shape[0]
+     _validate_args(
+         lhs=lhs,
+         rhs=rhs,
+         group_sizes=group_sizes,
+         rhs_scale=rhs_scale,
+         rhs_bias=rhs_bias,
+     )
+
+     # Gather shape information.
+     m, k, n = (lhs.shape[0], lhs.shape[1], rhs.shape[1])
+
+     # If tiling is callable, look up the problem dimensions in the LUT. If no
+     # tuned tile dimensions are available, throw an error.
+     if callable(tiling):
+         tiling = tiling(m, k, n)
+
+     if tiling is None:
+         raise ValueError(
+             f"No tuned tiling found for (m, k, n) = ({m}, {k}, {n})")
+
+     tm, tk, tn = tiling
+
+     if rhs_scale is not None:
+         assert isinstance(rhs_scale, jax.Array)
+         assert rhs_scale.shape[0] == num_current_groups
+         num_quant_blocks = rhs_scale.shape[1]
+     else:
+         num_quant_blocks = 1
+     block_size = k // num_quant_blocks
+
+     if tk > block_size or block_size % tk != 0:
+         tk = block_size
+
+     tiles_k, k_rem = _calculate_irregular_num_tiles(k, tk)
+     tiles_n, n_rem = _calculate_irregular_num_tiles(n, tn)
+     del n_rem
+
+     tiles_k //= num_quant_blocks
+
+     # Create the metadata we need for computation.
+     group_metadata, num_active_tiles = make_group_metadata(  # pylint: disable=unbalanced-tuple-unpacking
+         group_sizes=group_sizes,
+         m=m,
+         tm=tm,
+         start_group=group_offset[0],
+         num_nonzero_groups=rhs.shape[0],
+         visit_empty_groups=False,
+     )
+
+     def kernel(
+         group_metadata,
+         group_offset,
+         lhs,
+         rhs,
+         rhs_scale,
+         rhs_bias,
+         existing_out,
+         out,
+         acc_scratch,
+     ):
+         group_offsets, group_ids, m_tile_ids = group_metadata
+         del group_offsets, group_ids, group_offset
+
+         grid_id = pl.program_id(1)
+         b_i = pl.program_id(2)
+         k_i = pl.program_id(3)
+
+         @pl.when(k_i == 0)
+         def _zero_acc():
+             acc_scratch[...] = jnp.zeros_like(acc_scratch)
+
+         if existing_out is not None:
+             prev_grid_id = jnp.where(grid_id > 0, grid_id - 1, 0)
+             is_first_processed_group = grid_id == 0
+             m_tile_changed = m_tile_ids[grid_id] != m_tile_ids[
+                 prev_grid_id]
+             first_time_seeing_out = jnp.logical_or(
+                 is_first_processed_group, m_tile_changed)
+
+             @pl.when(first_time_seeing_out)
+             def _init_out():
+                 out[...] = existing_out[...]
+
+         def mask_k_rem(x, *, dim):
+             if k_rem == 0:
+                 return x
+
+             orig_dtype = x.dtype
+             iota = lax.broadcasted_iota(jnp.int32, x.shape, dim)
+             x = x.astype(jnp.float32)
+             return jnp.where(iota < k_rem, x, 0).astype(orig_dtype)
+
+         def _accum(is_last_k_tile, is_first_b_tile):
+             if is_last_k_tile:
+                 mask_k_rem_lhs = partial(mask_k_rem, dim=1)
+                 mask_k_rem_rhs = partial(mask_k_rem, dim=1)
+             else:
+
+                 def _wrapper(x):
+                     return x
+
+                 mask_k_rem_lhs = _wrapper
+                 mask_k_rem_rhs = _wrapper
+
+             loaded_lhs = lhs[...]
+             loaded_rhs = rhs[...]
+
+             acc = acc_scratch[...] + jax.lax.dot_general(
+                 mask_k_rem_lhs(loaded_lhs),
+                 mask_k_rem_rhs(loaded_rhs),
+                 preferred_element_type=jnp.float32,
+                 dimension_numbers=(((1, ), (1, )), ((), ())),
+             )
+
+             if is_last_k_tile:
+                 if rhs_scale is not None:
+                     acc *= jnp.broadcast_to(rhs_scale[...], acc.shape)
+
+                 loaded_out = out[...].astype(jnp.float32)
+                 if not is_first_b_tile:
+                     acc += loaded_out
+                 elif rhs_bias is not None:
+                     acc += rhs_bias[...].astype(jnp.float32)
+
+                 mask = _get_store_mask(
+                     grid_id=grid_id,
+                     group_metadata=group_metadata,
+                     tm=tm,
+                     tn=tn,
+                 )
+                 out[...] = jax.lax.select(
+                     mask[...], acc, loaded_out).astype(preferred_element_type)
+             else:
+                 acc_scratch[...] = acc
+
+         is_last_k_tile = k_i == (tiles_k - 1)
+         is_first_b_tile = b_i == 0
+
+         lax.cond(
+             is_last_k_tile,
+             lambda: lax.cond(
+                 is_first_b_tile,
+                 partial(_accum, True, True),
+                 partial(_accum, True, False),
+             ),
+             partial(_accum, False, False),
+         )
+
+     def lhs_transform_indices(n_i, grid_id, b_i, k_i, group_metadata,
+                               group_offset):
+         # lhs is (m, k). Load the [tm, tk] matrix for this m-tile.
+         group_offsets, group_ids, m_tile_ids = group_metadata
+         del n_i, group_offsets, group_ids, group_offset
+         return m_tile_ids[grid_id], b_i * tiles_k + k_i
+
+     def rhs_transform_indices(n_i, grid_id, b_i, k_i, group_metadata,
+                               group_offset):
+         # rhs is (num_groups, k, n). Load the [tk, tn] matrix based on the group id
+         # for this m-tile.
+         group_offsets, group_ids, m_tile_ids = group_metadata
+         del group_offsets, m_tile_ids
+
+         # NOTE: If we're working on only a shard of the rhs we need to adjust the
+         # group index we load from to account for this. The group_ids are in the
+         # "unsharded" domain.
+         return group_ids[grid_id] - group_offset[0], n_i, b_i * tiles_k + k_i
+
+     def rhs_scale_transform_indices(n_i, grid_id, b_i, k_i, group_metadata,
+                                     group_offset):
+         group_offsets, group_ids, m_tile_ids = group_metadata
+         del group_offsets, m_tile_ids, k_i
+         return group_ids[grid_id] - group_offset[0], b_i, 0, n_i
+
+     def rhs_bias_transform_indices(n_i, grid_id, b_i, k_i, group_metadata,
+                                    group_offset):
+         group_offsets, group_ids, m_tile_ids = group_metadata
+         del group_offsets, m_tile_ids, k_i, b_i
+         return group_ids[grid_id] - group_offset[0], 0, n_i
+
+     def out_transform_indices(n_i, grid_id, b_i, k_i, group_metadata,
+                               group_offset):
+         # out is (m, n). Load the [tm, tn] matrix for this m-tile.
+         group_offsets, group_ids, m_tile_ids = group_metadata
+         del k_i, group_offsets, group_ids, group_offset, b_i
+         return m_tile_ids[grid_id], n_i
+
+     out_block_spec = pl.BlockSpec((tm, tn), out_transform_indices)
+     if existing_out is None:
+         in_out_block_spec: Any = None
+         input_output_aliases = {}
+     else:
+         in_out_block_spec = out_block_spec
+         input_output_aliases = {7: 0}
+
+     lhs_block_spec = pl.BlockSpec((tm, tk), lhs_transform_indices)
+     rhs_block_spec = pl.BlockSpec((None, tn, tk), rhs_transform_indices)
+
+     if rhs_scale is None:
+         rhs_scale_block_spec = None
+     else:
+         rhs_scale_block_spec = pl.BlockSpec((None, None, 1, tn),
+                                             rhs_scale_transform_indices)
+
+     if rhs_bias is None:
+         rhs_bias_block_spec = None
+     else:
+         rhs_bias_block_spec = pl.BlockSpec((None, 1, tn),
+                                            rhs_bias_transform_indices)
+
+     lhs_bytes = lhs.size * lhs.itemsize
+     rhs_bytes = (k * n) * rhs.itemsize  # We don't read all of rhs
+     if rhs_scale is not None:
+         rhs_bytes += (num_quant_blocks * n) * rhs_scale.itemsize
+     if rhs_bias is not None:
+         rhs_bytes += n * rhs_bias.itemsize
+     out_bytes = (m * n) * jnp.dtype(preferred_element_type).itemsize
+     max_active_tiles = group_metadata[1].size
+     bytes_accessed = ((lhs_bytes * tiles_n) + (rhs_bytes * max_active_tiles) +
+                       out_bytes)
+     flops = 2 * m * k * n
+     cost_estimate = pl.CostEstimate(flops=flops,
+                                     bytes_accessed=bytes_accessed,
+                                     transcendentals=0)
+     call_gmm = pl.pallas_call(
+         kernel,
+         out_shape=jax.ShapeDtypeStruct((m, n), preferred_element_type),
+         grid_spec=pltpu.PrefetchScalarGridSpec(
+             num_scalar_prefetch=2,
+             in_specs=[
+                 lhs_block_spec,
+                 rhs_block_spec,
+                 rhs_scale_block_spec,
+                 rhs_bias_block_spec,
+                 in_out_block_spec,
+             ],
+             out_specs=out_block_spec,
+             grid=(tiles_n, num_active_tiles, num_quant_blocks, tiles_k),
+             scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)],
+         ),
+         input_output_aliases=input_output_aliases,
+         compiler_params=pltpu.CompilerParams(dimension_semantics=(
+             "parallel",
+             "arbitrary",
+             "arbitrary",
+             "arbitrary",
+         )),
+         interpret=interpret,
+         cost_estimate=cost_estimate,
+     )
+
+     out = call_gmm(
+         group_metadata,
+         group_offset,
+         lhs,
+         rhs,
+         rhs_scale,
+         rhs_bias,
+         existing_out,
+     )
+     if existing_out is None and num_current_groups < num_total_groups:
+         out = _zero_uninitialized_memory(
+             out,
+             start_group=group_offset[0],
+             num_nonzero_groups=rhs.shape[0],
+             group_metadata=group_metadata,
+         )
+     return out
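
The gmm entry point above performs one matmul per group over contiguous row blocks of lhs, which appears to be the pattern the new fused_moe_gmm layer builds on for routing tokens to experts. A minimal call sketch under assumed toy shapes (illustrative only; the kernel currently asserts transpose_rhs=True, expects rhs laid out as [num_groups, n, k], and requires group_sizes to sum to m with m divisible by the m-tile size):

import jax
import jax.numpy as jnp

from tpu_inference.kernels.megablox.gmm import gmm

# Assumed sizes: 4 expert groups, 512 rows, k=256 input and n=1024 output features.
m, k, n, num_groups = 512, 256, 1024, 4
lhs = jax.random.normal(jax.random.PRNGKey(0), (m, k), dtype=jnp.bfloat16)
rhs = jax.random.normal(jax.random.PRNGKey(1), (num_groups, n, k), dtype=jnp.bfloat16)
# Rows of lhs are assigned to groups contiguously; sizes must sum to m.
group_sizes = jnp.array([128, 256, 64, 64], dtype=jnp.int32)

out = gmm(lhs, rhs, group_sizes,
          preferred_element_type=jnp.float32,
          tiling=(128, 128, 128),
          transpose_rhs=True)
assert out.shape == (m, n)

Running this requires a TPU backend; per the docstring, interpret=True can be passed to run the Pallas kernel in interpret mode for debugging.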