tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +22 -1
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +167 -97
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +31 -9
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +280 -210
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +77 -36
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +91 -31
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -4
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -71
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +158 -63
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +53 -30
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +105 -57
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +65 -19
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +65 -52
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
- tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511270815.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
tpu_inference/kernels/megablox/gmm.py
@@ -0,0 +1,646 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Grouped matrix multiplication kernels for TPU written in Pallas."""

import functools
from collections.abc import Callable
from typing import Any, Optional

import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

from tpu_inference.kernels.megablox import common

partial = functools.partial


def _validate_args(
    *,
    lhs: jnp.ndarray,
    rhs: jnp.ndarray,
    group_sizes: jnp.ndarray,
    rhs_scale: jnp.ndarray | None = None,
    rhs_bias: jnp.ndarray | None = None,
):
    """Validates the arguments for the gmm function."""
    # Validate 'lhs'.
    if lhs.ndim != 2:
        raise ValueError(f"Expected 2-tensor for 'lhs' but got {lhs.ndim=}.")
    common.assert_is_supported_dtype(lhs.dtype)

    # Validate 'rhs'.
    if rhs.ndim != 3:
        raise ValueError(f"Expected 3-tensor for 'rhs' but got {rhs.ndim=}.")
    common.assert_is_supported_dtype(rhs.dtype)

    if lhs.shape[1] != rhs.shape[2]:
        raise ValueError(
            "Expected 'lhs' and 'rhs' to have the same number of input features."
            f" But instead got {lhs.shape[1]=} and {rhs.shape[2]=}")

    # Validate 'group_sizes'.
    if group_sizes.dtype != jnp.int32:
        raise ValueError(
            f"Expected 32-bit integer 'group_sizes' but got {group_sizes.dtype=}."
        )

    num_groups, out_size, in_size = rhs.shape

    if rhs_scale is not None:
        # Validate 'rhs_scale'.
        if rhs_scale.ndim != 4:
            raise ValueError(
                f"Expected 4-tensor for 'rhs_scale' but got {rhs_scale.ndim=}."
            )
        expected_rhs_scale_shape = (num_groups, rhs_scale.shape[1], 1, out_size)
        if rhs_scale.shape != expected_rhs_scale_shape:
            raise ValueError(
                "Expected 'rhs_scale' to have the shape of"
                f" {expected_rhs_scale_shape} but got {rhs_scale.shape=}.")

    if rhs_bias is not None:
        # Validate 'rhs_bias'.
        if rhs_bias.ndim != 3:
            raise ValueError(
                f"Expected 3-tensor for 'rhs_bias' but got {rhs_bias.ndim=}.")
        expected_rhs_bias_shape = (num_groups, 1, out_size)
        if rhs_bias.shape != expected_rhs_bias_shape:
            raise ValueError(
                "Expected 'rhs_bias' to have the shape of"
                f" {expected_rhs_bias_shape} but got {rhs_bias.shape=}.")


def _calculate_num_tiles(x: int, tx: int) -> int:
    tiles, rem = divmod(x, tx)
    if rem:
        raise ValueError(
            f"{x} must be divisible by x-dimension tile size ({tx}).")
    return tiles


def _calculate_irregular_num_tiles(x: int, tx: int) -> tuple[int, int]:
    tiles, rem = divmod(x, tx)
    if rem:
        tiles += 1
    return tiles, rem


GroupMetadata = Any  # TODO(enriqueps): Clean this up and use a namedtuple


def make_group_metadata(
    *,
    group_sizes: jnp.ndarray,
    m: int,
    tm: int,
    start_group: jnp.ndarray,
    num_nonzero_groups: int,
    visit_empty_groups: bool = True,
) -> GroupMetadata:
    """Create the metadata needed for grouped matmul computation.

    Args:
      group_sizes: A 1d, jnp.ndarray with shape [num_groups] and jnp.int32 dtype.
      m: The number of rows in lhs.
      tm: The m-dimension tile size being used.
      start_group: The group in group sizes to start computing from. This is
        particularly useful for when rhs num_groups is sharded.
      num_nonzero_groups: Number of groups in group sizes to compute on. Useful in
        combination with group_offset.
      visit_empty_groups: If True, do not squeeze tiles for empty groups out of
        the metadata. This is necessary for tgmm, where we at least need to zero
        the output for each group.

    Returns:
      tuple of:
        group_offsets: A 1d, jnp.ndarray with shape [num_groups+1] and jnp.int32
          dtype. group_offsets[i] indicates the row at which group [i] starts in
          the lhs matrix and group_offsets[i-1] = m.
        group_ids: A 1d, jnp.ndarray with shape [m_tiles + num_groups] and
          jnp.int32 dtype. group_ids[i] indicates which group grid index 'i' will
          work on.
        m_tile_ids: A 1d, jnp.ndarray with shape [m_tiles + num_groups] and
          jnp.int32. m_tile_ids[i] indicates which m-dimension tile grid index 'i'
          will work on.
      num_tiles: The number of m-dimension tiles to execute.
    """
    num_groups = group_sizes.shape[0]
    end_group = start_group + num_nonzero_groups - 1

    # Calculate the offset of each group, starting at zero. This metadata is
    # similar to row offsets in a CSR matrix. The following properties hold:
    #
    # group_offsets.shape = [num_groups + 1]
    # group_offsets[0] = 0
    # group_offsets[num_groups] = m
    #
    # The row at which group 'i' starts is group_offsets[i].
    group_ends = jnp.cumsum(group_sizes)
    group_offsets = jnp.concatenate(
        [jnp.zeros(1, dtype=jnp.int32), group_ends])

    # Assign a group id to each grid index.
    #
    # If a group starts somewhere other than the start of a tile or ends somewhere
    # other than the end of a tile we need to compute that full tile. Calculate
    # the number of tiles for each group by rounding their end up to the nearest
    # 'tm' and their start down to the nearest 'tm'.

    # (1) Round the group_ends up to the nearest multiple of 'tm'.
    #
    # NOTE: This does not change group_offsets[num_groups], which is m
    # (because we enforce m is divisible by tm).
    rounded_group_ends = ((group_ends + tm - 1) // tm * tm).astype(jnp.int32)

    # (2) Round the group_starts down to the nearest multiple of 'tm'.
    group_starts = jnp.concatenate(
        [jnp.zeros(1, dtype=jnp.int32), group_ends[:-1]])
    rounded_group_starts = group_starts // tm * tm

    # (3) Calculate the number of rows in each group.
    #
    # NOTE: Handle zero-sized groups as a special case. If the start for a
    # zero-sized group is not divisible by 'tm' its start will be rounded down and
    # its end will be rounded up such that its size will become 1 tile here.
    rounded_group_sizes = rounded_group_ends - rounded_group_starts
    rounded_group_sizes = jnp.where(group_sizes == 0, 0, rounded_group_sizes)

    # (4) Convert the group sizes from units of rows to unit of 'tm' sized tiles.
    #
    # An m-dimension tile is 'owned' by group 'i' if the first row of the tile
    # belongs to group 'i'. In addition to owned tiles, each group can have 0 or 1
    # initial partial tiles if it's first row does not occur in the first row of a
    # tile. The '0-th' group never has a partial tile because it always starts at
    # the 0-th row.
    #
    # If no group has a partial tile, the total number of tiles is equal to
    # 'm // tm'. If every group has a partial except the 0-th group, the total
    # number of tiles is equal to 'm // tm + num_groups - 1'. Thus we know that
    #
    # tiles_m <= group_tiles.sum() <= tiles_m + num_groups - 1
    #
    # Where tiles_m = m // tm.
    #
    # NOTE: All group sizes are divisible by 'tm' because of the rounding in steps
    # (1) and (2) so this division is exact.
    group_tiles = rounded_group_sizes // tm

    if visit_empty_groups:
        # Insert one tile for empty groups.
        group_tiles = jnp.where(group_sizes == 0, 1, group_tiles)

    # Create the group ids for each grid index based on the tile counts for each
    # group.
    #
    # NOTE: This repeat(...) will pad group_ids with the final group id if
    # group_tiles.sum() < tiles_m + num_groups - 1. The kernel grid will be sized
    # such that we only execute the necessary number of tiles.
    tiles_m = _calculate_num_tiles(m, tm)
    group_ids = jnp.repeat(
        jnp.arange(num_groups, dtype=jnp.int32),
        group_tiles,
        total_repeat_length=tiles_m + num_groups - 1,
    )

    # Assign an m-dimension tile id to each grid index.
    #
    # NOTE: Output tiles can only be re-visited consecutively. The following
    # procedure guarantees that m-dimension tile indices respect this.

    # (1) Calculate how many times each m-dimension tile will be visited.
    #
    # Each tile is guaranteed to be visited once by the group that owns the tile.
    # The remaining possible visits occur when a group starts inside of a tile at
    # a position other than the first row. We can calculate which m-dimension tile
    # each group starts in by floor-dividing its offset with `tm` and then count
    # tile visits with a histogram.
    #
    # To avoid double counting tile visits from the group that owns the tile,
    # filter these out by assigning their tile id to `tile_m` (one beyond the max)
    # such that they're ignored by the subsequent histogram. Also filter out any
    # group which is empty.
    #
    # TODO(tgale): Invert the 'partial_tile_mask' predicates to be more clear.
    partial_tile_mask = jnp.logical_or((group_offsets[:-1] % tm) == 0,
                                       group_sizes == 0)

    # Explicitly enable tiles for zero sized groups, if specified. This covers
    # zero sized groups that start on a tile-aligned row and those that do not.
    if visit_empty_groups:
        partial_tile_mask = jnp.where(group_sizes == 0, 0, partial_tile_mask)

    partial_tile_ids = jnp.where(partial_tile_mask, tiles_m,
                                 group_offsets[:-1] // tm)

    tile_visits = (jnp.histogram(
        partial_tile_ids, bins=tiles_m, range=(0, tiles_m - 1))[0] + 1)

    # Create the m-dimension tile ids for each grid index based on the visit
    # counts for each tile.
    m_tile_ids = jnp.repeat(
        jnp.arange(tiles_m, dtype=jnp.int32),
        tile_visits.astype(jnp.int32),
        total_repeat_length=tiles_m + num_groups - 1,
    )

    # Account for sharding.
    #
    # Find the start of the groups owned by our shard and shift the group_ids and
    # m_tile_ids s.t. the metadata for our tiles are at the front of the arrays.
    #
    # TODO(tgale): Move this offset into the kernel to avoid these rolls.
    first_tile_in_shard = (group_ids < start_group).sum()
    group_ids = jnp.roll(group_ids, shift=-first_tile_in_shard, axis=0)
    m_tile_ids = jnp.roll(m_tile_ids, shift=-first_tile_in_shard, axis=0)

    # Calculate the number of tiles we need to compute for our shard.
    #
    # Remove tile visits that belong to a group not in our shard.
    iota = jnp.arange(num_groups, dtype=jnp.int32)
    active_group_mask = jnp.logical_and(iota <= end_group, iota >= start_group)
    group_tiles = jnp.where(active_group_mask, group_tiles, 0)
    num_tiles = group_tiles.sum()
    return (group_offsets, group_ids, m_tile_ids), num_tiles


def _get_store_mask(
    *,
    grid_id: jnp.ndarray,
    group_metadata: GroupMetadata,
    tm: int,
    tn: int,
) -> jnp.ndarray:
    """Mask for rows that belong to the current group in the current tile."""
    group_offsets, group_ids, m_tile_ids = group_metadata[:3]
    group_id = group_ids[grid_id]
    group_start = group_offsets[group_id]
    group_end = group_offsets[group_id + 1]
    m_id = m_tile_ids[grid_id] * tm
    iota = jax.lax.broadcasted_iota(jnp.int32, (tm, tn), 0) + m_id
    return jnp.logical_and(iota >= group_start, iota < group_end)


def _zero_uninitialized_memory(
    out: jnp.ndarray,
    *,
    start_group: jnp.ndarray,
    num_nonzero_groups: int,
    group_metadata: GroupMetadata,
) -> jnp.ndarray:
    """Zero out uninitialized memory from output."""
    group_offsets = group_metadata[0]
    group_start = group_offsets[start_group]
    group_end = group_offsets[start_group + num_nonzero_groups]
    valid_mask = jax.lax.broadcasted_iota(jnp.int32, (out.shape[0], ), 0)
    valid_mask = (valid_mask >= group_start) & (valid_mask < group_end)
    return jnp.where(valid_mask[:, None], out, 0)


LutFn = Callable[[int, int, int], Optional[tuple[int, int, int]]]


@functools.partial(
    jax.jit,
    static_argnames=[
        "preferred_element_type",
        "tiling",
        "transpose_rhs",
        "interpret",
    ],
)
def gmm(
    lhs: jnp.ndarray,
    rhs: jnp.ndarray,
    group_sizes: jnp.ndarray,
    preferred_element_type: jnp.dtype = jnp.float32,
    rhs_scale: jnp.ndarray | None = None,
    rhs_bias: jnp.ndarray | None = None,
    tiling: tuple[int, int, int] | LutFn | None = (128, 128, 128),
    group_offset: jnp.ndarray | None = None,
    existing_out: jnp.ndarray | None = None,
    transpose_rhs: bool = False,
    interpret: bool = False,
) -> jnp.ndarray:
    """Compute lhs[sizes[i-1]:sizes[i], :] @ rhs for each group 'i'.

    Args:
      lhs: A 2d, jnp.ndarray with shape [m, k].
      rhs: A 3d, jnp.ndarray with shape [num_groups, n, k].
      group_sizes: A 1d, jnp.ndarray with shape [num_groups] and jnp.int32 dtype.
      preferred_element_type: jnp.dtype, the element type for the output matrix.
      rhs_scale: A 4d, jnp.ndarray with shape [num_groups, num_blocks, 1, n].
      rhs_bias: A 3d, jnp.ndarray with shape [num_groups, 1, n].
      tiling: 3-tuple of ints. The m, k and n-dimension tile sizes.
      group_offset: The group in group sizes to start computing from. This is
        particularly useful for when rhs num_groups is sharded.
      existing_out: Existing output to write to.
      transpose_rhs: True if the rhs needs to be transposed.
      interpret: Whether or not to run the kernel in interpret mode, helpful for
        testing and debugging.

    Returns:
      A 2d, jnp.ndarray with shape [m, n].
    """

    # TODO(kyuyeunk): Instead of transpose_rhs==True, modify logic to only
    # transpose_rhs==False instead as it simplifies the logic in kernel.
    assert transpose_rhs

    if existing_out is not None:
        assert isinstance(existing_out, jax.Array)
        expected_dtype = existing_out.dtype
        if expected_dtype != preferred_element_type:
            raise ValueError(
                "Existing output dtype must match preferred_element_type.")
    if group_offset is None:
        group_offset = jnp.array([0], dtype=jnp.int32)
    else:
        if group_offset.shape:
            raise ValueError(
                f"group_offset must be a ()-shaped array. Got: {group_offset.shape}."
            )
        group_offset = group_offset[None]
    num_current_groups = rhs.shape[0]
    num_total_groups = group_sizes.shape[0]
    _validate_args(
        lhs=lhs,
        rhs=rhs,
        group_sizes=group_sizes,
        rhs_scale=rhs_scale,
        rhs_bias=rhs_bias,
    )

    # Gather shape information.
    m, k, n = (lhs.shape[0], lhs.shape[1], rhs.shape[1])

    # If tiling is callable, look up the problem dimensions in the LUT. If no
    # tuned tile dimensions are available throw an error.
    if callable(tiling):
        tiling = tiling(m, k, n)

    if tiling is None:
        raise ValueError(
            f"No tuned tiling found for (m, k, n) = ({m}, {k}, {n})")

    tm, tk, tn = tiling

    if rhs_scale is not None:
        assert isinstance(rhs_scale, jax.Array)
        assert rhs_scale.shape[0] == num_current_groups
        num_quant_blocks = rhs_scale.shape[1]
    else:
        num_quant_blocks = 1
    block_size = k // num_quant_blocks

    if tk > block_size or block_size % tk != 0:
        tk = block_size

    tiles_k, k_rem = _calculate_irregular_num_tiles(k, tk)
    tiles_n, n_rem = _calculate_irregular_num_tiles(n, tn)
    del n_rem

    tiles_k //= num_quant_blocks

    # Create the metadata we need for computation.
    group_metadata, num_active_tiles = make_group_metadata(  # pylint: disable=unbalanced-tuple-unpacking
        group_sizes=group_sizes,
        m=m,
        tm=tm,
        start_group=group_offset[0],
        num_nonzero_groups=rhs.shape[0],
        visit_empty_groups=False,
    )

    def kernel(
        group_metadata,
        group_offset,
        lhs,
        rhs,
        rhs_scale,
        rhs_bias,
        existing_out,
        out,
        acc_scratch,
    ):
        group_offsets, group_ids, m_tile_ids = group_metadata
        del group_offsets, group_ids, group_offset

        grid_id = pl.program_id(1)
        b_i = pl.program_id(2)
        k_i = pl.program_id(3)

        @pl.when(k_i == 0)
        def _zero_acc():
            acc_scratch[...] = jnp.zeros_like(acc_scratch)

        if existing_out is not None:
            prev_grid_id = jnp.where(grid_id > 0, grid_id - 1, 0)
            is_first_processed_group = grid_id == 0
            m_tile_changed = m_tile_ids[grid_id] != m_tile_ids[prev_grid_id]
            first_time_seeing_out = jnp.logical_or(is_first_processed_group,
                                                   m_tile_changed)

            @pl.when(first_time_seeing_out)
            def _init_out():
                out[...] = existing_out[...]

        def mask_k_rem(x, *, dim):
            if k_rem == 0:
                return x

            orig_dtype = x.dtype
            iota = lax.broadcasted_iota(jnp.int32, x.shape, dim)
            x = x.astype(jnp.float32)
            return jnp.where(iota < k_rem, x, 0).astype(orig_dtype)

        def _accum(is_last_k_tile, is_first_b_tile):
            if is_last_k_tile:
                mask_k_rem_lhs = partial(mask_k_rem, dim=1)
                mask_k_rem_rhs = partial(mask_k_rem, dim=1)
            else:

                def _wrapper(x):
                    return x

                mask_k_rem_lhs = _wrapper
                mask_k_rem_rhs = _wrapper

            loaded_lhs = lhs[...]
            loaded_rhs = rhs[...]

            acc = acc_scratch[...] + jax.lax.dot_general(
                mask_k_rem_lhs(loaded_lhs),
                mask_k_rem_rhs(loaded_rhs),
                preferred_element_type=jnp.float32,
                dimension_numbers=(((1, ), (1, )), ((), ())),
            )

            if is_last_k_tile:
                if rhs_scale is not None:
                    acc *= jnp.broadcast_to(rhs_scale[...], acc.shape)

                loaded_out = out[...].astype(jnp.float32)
                if not is_first_b_tile:
                    acc += loaded_out
                elif rhs_bias is not None:
                    acc += rhs_bias[...].astype(jnp.float32)

                mask = _get_store_mask(
                    grid_id=grid_id,
                    group_metadata=group_metadata,
                    tm=tm,
                    tn=tn,
                )
                out[...] = jax.lax.select(
                    mask[...], acc, loaded_out).astype(preferred_element_type)
            else:
                acc_scratch[...] = acc

        is_last_k_tile = k_i == (tiles_k - 1)
        is_first_b_tile = b_i == 0

        lax.cond(
            is_last_k_tile,
            lambda: lax.cond(
                is_first_b_tile,
                partial(_accum, True, True),
                partial(_accum, True, False),
            ),
            partial(_accum, False, False),
        )

    def lhs_transform_indices(n_i, grid_id, b_i, k_i, group_metadata,
                              group_offset):
        # lhs is (m, k). Load the [tm, tk] matrix for this m-tile.
        group_offsets, group_ids, m_tile_ids = group_metadata
        del n_i, group_offsets, group_ids, group_offset
        return m_tile_ids[grid_id], b_i * tiles_k + k_i

    def rhs_transform_indices(n_i, grid_id, b_i, k_i, group_metadata,
                              group_offset):
        # rhs is (num_groups, k, n). Load the [tk, tn] matrix based on the group id
        # for this m-tile.
        group_offsets, group_ids, m_tile_ids = group_metadata
        del group_offsets, m_tile_ids

        # NOTE: If we're working on only a shard of the rhs we need to adjust the
        # group index we load from to account for this. The group_ids are in the
        # "unsharded" domain.
        return group_ids[grid_id] - group_offset[0], n_i, b_i * tiles_k + k_i

    def rhs_scale_transform_indices(n_i, grid_id, b_i, k_i, group_metadata,
                                    group_offset):
        group_offsets, group_ids, m_tile_ids = group_metadata
        del group_offsets, m_tile_ids, k_i
        return group_ids[grid_id] - group_offset[0], b_i, 0, n_i

    def rhs_bias_transform_indices(n_i, grid_id, b_i, k_i, group_metadata,
                                   group_offset):
        group_offsets, group_ids, m_tile_ids = group_metadata
        del group_offsets, m_tile_ids, k_i, b_i
        return group_ids[grid_id] - group_offset[0], 0, n_i

    def out_transform_indices(n_i, grid_id, b_i, k_i, group_metadata,
                              group_offset):
        # out is (m, n). Load the [tm, tn] matrix for this m-tile.
        group_offsets, group_ids, m_tile_ids = group_metadata
        del k_i, group_offsets, group_ids, group_offset, b_i
        return m_tile_ids[grid_id], n_i

    out_block_spec = pl.BlockSpec((tm, tn), out_transform_indices)
    if existing_out is None:
        in_out_block_spec: Any = None
        input_output_aliases = {}
    else:
        in_out_block_spec = out_block_spec
        input_output_aliases = {7: 0}

    lhs_block_spec = pl.BlockSpec((tm, tk), lhs_transform_indices)
    rhs_block_spec = pl.BlockSpec((None, tn, tk), rhs_transform_indices)

    if rhs_scale is None:
        rhs_scale_block_spec = None
    else:
        rhs_scale_block_spec = pl.BlockSpec((None, None, 1, tn),
                                            rhs_scale_transform_indices)

    if rhs_bias is None:
        rhs_bias_block_spec = None
    else:
        rhs_bias_block_spec = pl.BlockSpec((None, 1, tn),
                                           rhs_bias_transform_indices)

    lhs_bytes = lhs.size * lhs.itemsize
    rhs_bytes = (k * n) * rhs.itemsize  # We don't read all of rhs
    if rhs_scale is not None:
        rhs_bytes += (num_quant_blocks * n) * rhs_scale.itemsize
    if rhs_bias is not None:
        rhs_bytes += n * rhs_bias.itemsize
    out_bytes = (m * n) * jnp.dtype(preferred_element_type).itemsize
    max_active_tiles = group_metadata[1].size
    bytes_accessed = ((lhs_bytes * tiles_n) + (rhs_bytes * max_active_tiles) +
                      out_bytes)
    flops = 2 * m * k * n
    cost_estimate = pl.CostEstimate(flops=flops,
                                    bytes_accessed=bytes_accessed,
                                    transcendentals=0)
    call_gmm = pl.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((m, n), preferred_element_type),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=2,
            in_specs=[
                lhs_block_spec,
                rhs_block_spec,
                rhs_scale_block_spec,
                rhs_bias_block_spec,
                in_out_block_spec,
            ],
            out_specs=out_block_spec,
            grid=(tiles_n, num_active_tiles, num_quant_blocks, tiles_k),
            scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)],
        ),
        input_output_aliases=input_output_aliases,
        compiler_params=pltpu.CompilerParams(dimension_semantics=(
            "parallel",
            "arbitrary",
            "arbitrary",
            "arbitrary",
        )),
        interpret=interpret,
        cost_estimate=cost_estimate,
    )

    out = call_gmm(
        group_metadata,
        group_offset,
        lhs,
        rhs,
        rhs_scale,
        rhs_bias,
        existing_out,
    )
    if existing_out is None and num_current_groups < num_total_groups:
        out = _zero_uninitialized_memory(
            out,
            start_group=group_offset[0],
            num_nonzero_groups=rhs.shape[0],
            group_metadata=group_metadata,
        )
    return out
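For orientation, here is a minimal sketch of how the gmm kernel above might be called. The shapes follow the docstring (lhs [m, k], rhs [num_groups, n, k] with transpose_rhs=True, group_sizes summing to m); the problem sizes, dtype, and tiling below are illustrative assumptions rather than values taken from the package, and the import path is inferred from the file list above.

# Illustrative usage sketch; sizes, dtype, and tiling are assumptions.
import jax.numpy as jnp

from tpu_inference.kernels.megablox.gmm import gmm

m, k, n, num_groups = 512, 1024, 2048, 8
lhs = jnp.ones((m, k), dtype=jnp.bfloat16)              # [m, k] activations
rhs = jnp.ones((num_groups, n, k), dtype=jnp.bfloat16)  # [num_groups, n, k]
# Rows of lhs owned by each group; int32 and summing to m, per _validate_args.
group_sizes = jnp.full((num_groups,), m // num_groups, dtype=jnp.int32)

out = gmm(
    lhs,
    rhs,
    group_sizes,
    preferred_element_type=jnp.bfloat16,
    tiling=(128, 128, 128),  # (tm, tk, tn); m must be divisible by tm
    transpose_rhs=True,      # this version of the kernel asserts transpose_rhs
    # interpret=True,        # optionally run in Pallas interpret mode off-TPU
)
assert out.shape == (m, n)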
@@ -0,0 +1,13 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,13 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.