tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +317 -34
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +26 -6
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +25 -4
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +25 -12
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +32 -9
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +101 -494
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
- tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +112 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +18 -5
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +179 -51
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
- tpu_inference/models/jax/utils/weight_utils.py +234 -155
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +51 -72
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +180 -80
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +55 -33
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +16 -3
- tpu_inference/runner/tpu_runner.py +124 -61
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +84 -22
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +66 -52
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -186
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tpu_inference/layers/vllm/fused_moe.py
@@ -1,507 +1,114 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import Enum
 
 import jax
-from jax.
-from
+import jax.numpy as jnp
+import torch
+from jax.sharding import Mesh
+from torchax.interop import jax_view, torch_view
+from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 
-from tpu_inference
+from tpu_inference import envs
+from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
+from tpu_inference.layers.common.fused_moe_gmm import fused_moe_func
+from tpu_inference.logger import init_logger
 
+logger = init_logger(__name__)
 
 
-        case "swigluoai":
-            return _swigluoai(x1, x2)
-        case _:
-            raise NotImplementedError(
-                f"FusedMoE does not support {activation} activation")
+class FusedMoEBackend(Enum):
+    FUSED_MOE = "fused_moe"
+    GMM_EP = "gmm_ep"
+    GMM_TP = "gmm_tp"
 
 
-def
+def select_moe_backend(moe: FusedMoEConfig):
+    if envs.USE_MOE_EP_KERNEL:
+        if moe.use_ep:
+            return FusedMoEBackend.FUSED_MOE
+        logger.warning_once(
+            "USE_MOE_EP_KERNEL=1 but expert parallelism is not "
+            "enabled. Falling back to gmm implementation.")
 
+    if moe.use_ep:
+        return FusedMoEBackend.GMM_EP
 
+    # Use default implementation.
+    return FusedMoEBackend.GMM_TP
 
 
-def
-    If `x` is less than or equal to 128, returns 128.
-    If `x` is less than `limit`, returns the smallest multiple of 128 greater
-    than or equal to `x`.
-    If `x` is greater than or equal to `limit`, searches for the largest
-    multiple of 128 less than or equal to `limit` (down to 512) that divides `x`
-    evenly, and returns it.
-    If no such candidate is found, returns `limit`.
-
-    Args:
-        x (int): The integer to round up.
-        limit (int): The upper bound (must be a multiple of 128).
-
-    Returns:
-        int: The rounded value according to the rules above.
-
-    Raises:
-        AssertionError: If `limit` is less than 128 or not a multiple of 128.
-    """
-    assert limit >= 128 and limit % 128 == 0
-    if x <= 128:
-        return 128
-    if x < limit:
-        return (x + 127) // 128 * 128
-    for candidate in range(limit, 511, -128):
-        if x % candidate == 0:
-            return candidate
-    return limit
-
-
-def _get_tiling_size_for_gmm_kernel(m: int, k: int, n: int,
-                                    g: int) -> tuple[int, int, int]:
-    """
-    Calculate optimal tiling sizes for a GMM kernel in a Mixture of Experts
-    (MoE) setting.
-
-    Args:
-        m (int): The total number of tokens.
-        n (int): The output feature dimension.
-        k (int): The input feature dimension.
-        g (int): The number of experts.
-
-    Returns:
-        tuple[int, int, int]: A tuple (tm, tk, tn)
-    """
-
-    # TODO(Chengji): increase the upper limit tiling size of m when we can set
-    # the vmem size to be used for gmm kernel.
-    # NOTE: In average each expert has m // g tokens, but as it might be
-    # unbalanced, here we doubled the token size when choosing tiling size of m.
-    # 2m//g can be either greater or less than 512. If there are 32 tokens and
-    # topk=2, m=topk * num_tokens=64, in this case, 2*m//g will be less than
-    # 512.
-    tm = _round_up_to_multiple_of_128_within_limit(2 * m // g, 512)
-    tm = min(tm, m)  # there's a requirement that m % tm == 0
-    # k/n correspond to n_input_features/n_output_features in the matmul so they
-    # are normally greater than 2048, unless the num shards is large.
-    tk = _round_up_to_multiple_of_128_within_limit(k, 2048)
-    tn = _round_up_to_multiple_of_128_within_limit(n, 2048)
-    return tm, tk, tn
-
-
-def tensor_sharded_gmm_merged_column_parallel(
-    lhs: jax.Array,
-    rhs: jax.Array,
-    rhs_bias: jax.Array | None,
-    group_sizes: jax.Array,
-    transpose_rhs: bool,
-    mesh: Mesh,
-    intermediate_size: int,
-) -> jax.Array:
-    # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
-    m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
-    n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
-    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
-
-    _gmm = functools.partial(
-        gmm,
-        preferred_element_type=lhs.dtype,
-        tiling=(tm, tk, tn),
-        transpose_rhs=transpose_rhs,
-        group_offset=jnp.array(0),
-    )
-
-    gmm_result = shard_map(
-        _gmm,
-        mesh=mesh,
-        in_specs=(P(), P(None, "model", None), P()),
-        out_specs=(P(None, "model")),
-        check_rep=False,
-    )(lhs, rhs, group_sizes)
-
-    if rhs_bias is not None:
-        rhs_bis = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
-        gmm_result = (gmm_result + rhs_bis).astype(gmm_result.dtype)
-
-    n_shards = mesh.shape["model"]
-    output_sizes = [intermediate_size, intermediate_size]
-
-    return slice_sharded_tensor_for_concatenation(gmm_result, output_sizes,
-                                                  n_shards)
-
-
-def tensor_sharded_gmm_row_parallel(
-    lhs: jax.Array,
-    rhs: jax.Array,
-    rhs_bias: jax.Array | None,
-    group_sizes: jax.Array,
-    transpose_rhs: bool,
-    mesh: Mesh,
-) -> jax.Array:
-    # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
-    m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
-    n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
-    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
-
-    _gmm = functools.partial(
-        gmm,
-        preferred_element_type=lhs.dtype,
-        tiling=(tm, tk, tn),
-        transpose_rhs=transpose_rhs,
-        group_offset=jnp.array(0),
-    )
-
-    def _gmm_all_reduce(lhs, rhs, group_sizes):
-        r = _gmm(lhs, rhs, group_sizes)
-        return jax.lax.psum(r, axis_name="model")
-
-    gmm_result = shard_map(
-        _gmm_all_reduce,
-        mesh=mesh,
-        in_specs=(P(None, "model"), P(None, None, "model"), P()),
-        out_specs=(P()),
-        check_rep=False,
-    )(lhs, rhs, group_sizes)
-
-    if rhs_bias is not None:
-        rhs_bias = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
-        gmm_result = (gmm_result + rhs_bias).astype(gmm_result.dtype)
-
-    return gmm_result
-
-
-def expert_sharded_gmm(
-    lhs: jax.Array,
-    rhs: jax.Array,
-    group_sizes: jax.Array,
-    transpose_rhs: bool,
+def fused_moe_apply(
+    layer: torch.nn.Module,
+    x: torch.Tensor,
+    router_logits: torch.Tensor,
+    moe_backend: FusedMoEBackend,
     mesh: Mesh,
-    # For i-th shard, it is responsible groups (AKA experts) from
-    # i*num_experts_per_shard to (i+1)*num_experts_per_shard We sum them up to
-    # get total rows in that shard, and that is the size for shard to send to
-    # its peers. This is also the number of non-zero rows from the gmm results.
-    # In the working example, send_sizes would be [3, 2, 5, 4]
-    send_sizes = jnp.array([
-        group_sizes[i * num_experts_per_shard:(i + 1) *
-                    num_experts_per_shard].sum() for i in range(ep_size)
-    ])
-    # In the working example, input_offsets would be [0, 3, 5, 10]
-    input_offsets = jnp.concatenate((jnp.array([0]), send_sizes.cumsum()[:-1]))
-    output_offsets = input_offsets
-    recv_sizes = send_sizes
-
-    input_offsets = jax.lax.with_sharding_constraint(
-        input_offsets, NamedSharding(mesh, P("model")))
-    send_sizes = jax.lax.with_sharding_constraint(
-        send_sizes, NamedSharding(mesh, P("model")))
-    output_offsets = jax.lax.with_sharding_constraint(
-        output_offsets, NamedSharding(mesh, P("model")))
-
-    def _ragged_all_to_all(operand, input_offsets, send_sizes, output_offsets,
-                           recv_sizes):
-        output = jnp.zeros_like(operand)
-
-        # input_offsets, send_sizes and output_offsets are sharded and there is
-        # only 1 elemnt in each shard, we are taking the 0-th element from them
-        # just so that jnp.repeat generates the arrays with correct shape.
-        input_offsets_of_shard = jnp.repeat(input_offsets[0], ep_size)
-        send_sizes_of_shard = jnp.repeat(send_sizes[0], ep_size)
-        output_offsets_of_shard = jnp.repeat(output_offsets[0], ep_size)
-
-        # recv_sizes is replicated across shards, because all the shards receive
-        # the same data and write to the output in the same way (same
-        # output_offsets and same recv_sizes) and thus generates replicated
-        # output.
-        recv_sizes_of_shard = recv_sizes
-
-        # In the working example, for each shard, the values of the offsets and
-        # sizes would be:
-        #                           shard-0       shard-1       shard-2       shard-3
-        # input_offsets_of_shard    [0, 0, 0, 0]  [3, 3, 3, 3]  [5, 5, 5, 5]  [10,10,10,10]
-        # send_sizes_of_shard       [3, 3, 3, 3]  [2, 2, 2, 2]  [5, 5, 5, 5]  [4, 4, 4, 4 ]
-        # output_offsets_of_shard   [0, 0, 0, 0]  [0, 0, 0, 0]  [0, 0, 0, 0]  [10,10,10,10]
-        # recv_sizes_of_shard       [3, 2, 5, 4]  [3, 2, 5, 4]  [3, 2, 5, 4]  [3, 2, 5, 4]
-        return jax.lax.ragged_all_to_all(operand,
-                                         output,
-                                         input_offsets_of_shard,
-                                         send_sizes_of_shard,
-                                         output_offsets_of_shard,
-                                         recv_sizes_of_shard,
-                                         axis_name="model")
-
-    # Use ragged_all_to_all to send the result from gmm for each expert to all
-    # the shards. In the working example, the result would be:
-    # A, A, A, A    A, A, A, A    A, A, A, A    A, A, A, A
-    # A, A, A, A    A, A, A, A    A, A, A, A    A, A, A, A
-    # A, A, A, A    A, A, A, A    A, A, A, A    A, A, A, A
-    # B, B, B, B    B, B, B, B    B, B, B, B    B, B, B, B
-    # B, B, B, B    B, B, B, B    B, B, B, B    B, B, B, B
-    # C, C, C, C    C, C, C, C    C, C, C, C    C, C, C, C
-    # C, C, C, C    C, C, C, C    C, C, C, C    C, C, C, C
-    # C, C, C, C    C, C, C, C    C, C, C, C    C, C, C, C
-    # C, C, C, C    C, C, C, C    C, C, C, C    C, C, C, C
-    # C, C, C, C    C, C, C, C    C, C, C, C    C, C, C, C
-    # D, D, D, D    D, D, D, D    D, D, D, D    D, D, D, D
-    # D, D, D, D    D, D, D, D    D, D, D, D    D, D, D, D
-    # D, D, D, D    D, D, D, D    D, D, D, D    D, D, D, D
-    # D, D, D, D    D, D, D, D    D, D, D, D    D, D, D, D
-    #   shard-0       shard-1       shard-2       shard-3
-    return shard_map(
-        _ragged_all_to_all,
-        mesh=mesh,
-        in_specs=(P("model", None), P("model"), P("model"), P("model"), P()),
-        out_specs=(P()),
-        check_rep=False,
-    )(gmm_res, input_offsets, send_sizes, output_offsets, recv_sizes)
-
-
-def fused_moe_func(
-    hidden_states: jax.Array,
-    w1: jax.Array,
-    w2: jax.Array,
-    w1_bias: jax.Array | None,
-    w2_bias: jax.Array | None,
-    gating_output: jax.Array,
-    topk: int,
-    global_num_experts: int,
-    renormalize: bool,
-    reduce_results: bool,
-    mesh: Mesh,
-    use_ep: bool,
-    activation: str,
-):
-    """
-    Args:
-        hidden_states: [*, hidden_size]
-        w1: [num_experts, intermediate_size * 2, hidden_size]
-        w2: [num_experts, hidden_size, intermediate_size]
-        gating_output: [*, num_experts]
-    """
-    # adapted from https://github.com/vllm-project/vllm/blob/29fa5cac1cd731026f59084d93a822921507573c/vllm/model_executor/layers/fused_moe/moe_pallas.py#L26
-    if use_ep and (w1_bias is not None or w2_bias is not None):
-        raise NotImplementedError(
-            "Bias is not supported when using expert parallelism.")
-    orig_shape = hidden_states.shape
-    hidden_size = hidden_states.shape[-1]
-    num_tokens = hidden_states.size // hidden_size
-    assert global_num_experts == w1.shape[0]
-    ep_size = mesh.shape["model"]  # only used if use_ep is True.
-    intermediate_size = w2.shape[-1]
-    dtype = hidden_states.dtype
-    assert (num_tokens * topk) % 16 == 0, (
-        "The kernel requires num_tokens * topk to be a multiple of "
-        f"16 but got {num_tokens}*{topk}={num_tokens*topk}")
-
-    hidden_states = hidden_states.reshape(num_tokens, hidden_size)
-    gating_output = gating_output.reshape(num_tokens, global_num_experts)
-
-    topk_weights = jax.nn.softmax(gating_output.astype(jnp.float32), axis=-1)
-    topk_weights, topk_indices = jax.lax.top_k(topk_weights, k=topk)
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdims=True)
-    topk_weights = topk_weights.astype(dtype)
-
-    topk_indices_flat = topk_indices.flatten()
-    topk_argsort_indices = jnp.argsort(topk_indices_flat)
-    topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)
-    token_indices = jnp.arange(num_tokens, dtype=jnp.int32).repeat(topk)
-    token_indices_sorted = token_indices[topk_argsort_indices]
-    group_sizes = jnp.bincount(topk_indices_flat, length=global_num_experts)
-
-    x = hidden_states[token_indices_sorted]
-
-    if use_ep:
-        x = expert_sharded_gmm(
-            x,
-            w1,
-            group_sizes,
-            transpose_rhs=True,
-            mesh=mesh,
-            num_experts=global_num_experts,
-            ep_size=ep_size,
-        )
-        x1, x2 = x[..., :intermediate_size], x[..., intermediate_size:]
-    else:
-        x1, x2 = tensor_sharded_gmm_merged_column_parallel(
-            x,
-            w1,
-            w1_bias,
-            group_sizes,
-            transpose_rhs=True,
-            mesh=mesh,
-            intermediate_size=intermediate_size,
-        )
-
-    x = activation_fn(activation, x1, x2)
-
-    if use_ep:
-        x = expert_sharded_gmm(
-            x,
-            w2,
-            group_sizes,
-            transpose_rhs=True,
-            mesh=mesh,
-            num_experts=global_num_experts,
-            ep_size=ep_size,
-        )
-    else:
-        x = jax.lax.with_sharding_constraint(
-            x, NamedSharding(mesh, P(None, "model")))
-        x = tensor_sharded_gmm_row_parallel(
-            x,
-            w2,
-            w2_bias,
-            group_sizes,
-            transpose_rhs=True,
-            mesh=mesh,
-        )
-
-    x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)
-    x = x * jnp.expand_dims(topk_weights, axis=-1)
-    x = x.sum(axis=-2)
-    x = x.reshape(orig_shape)
-
-    if reduce_results:
-        x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P()))
-    return x
-
-
-@functools.partial(
-    jax.jit,
-    static_argnames=(
-        "topk",
-        "global_num_experts",
-        "renormalize",
-        "reduce_results",
-        "mesh",
-        "use_ep",
-        "activation",
-    ),
-)
-def fused_moe_func_padded(
-    hidden_states: jax.Array,
-    w1: jax.Array,
-    w2: jax.Array,
-    w1_bias: jax.Array | None,
-    w2_bias: jax.Array | None,
-    gating_output: jax.Array,
-    topk: int,
-    global_num_experts: int,
-    renormalize: bool,
-    reduce_results: bool,
-    mesh: Mesh,
-    use_ep: bool,
-    activation: str,
-):
-    # TODO(fanhongmin@google.com): Once the jax runner pads the input, we no longer need this.
-    hidden_size = hidden_states.shape[-1]
-    num_tokens = hidden_states.size // hidden_size
-    if num_tokens * topk < 16:
-        assert 16 % (num_tokens *
-                     topk) == 0, f"Cannot pad to 16: {num_tokens=}, {topk=}"
-        n_repeats = 16 // (num_tokens * topk)
-
-        reps = (n_repeats, ) + (1, ) * (hidden_states.ndim - 1)
-        expanded_hidden_states = jnp.tile(hidden_states, reps)
-
-        reps = (n_repeats, ) + (1, ) * (gating_output.ndim - 1)
-        expanded_gating_output = jnp.tile(gating_output, reps)
-
-        expanded_x = fused_moe_func(
-            expanded_hidden_states,
-            w1,
-            w2,
-            w1_bias,
-            w2_bias,
-            expanded_gating_output,
-            topk,
-            global_num_experts,
-            renormalize,
-            reduce_results,
-            mesh,
-            use_ep,
-            activation,
-        )
-        x = expanded_x[:hidden_states.shape[0]]
-        return x
-    else:
-        return fused_moe_func(
-            hidden_states,
-            w1,
-            w2,
-            w1_bias,
-            w2_bias,
-            gating_output,
-            topk,
-            global_num_experts,
-            renormalize,
-            reduce_results,
-            mesh,
-            use_ep,
-            activation,
-        )
+    extra_backend_kwargs: dict,
+) -> torch.Tensor:
+    assert isinstance(layer, FusedMoE)
+    if layer.scoring_func != "softmax":
+        raise NotImplementedError("Only softmax is supported for scoring_func")
+
+    x = jax_view(x)
+    gating_output = jax_view(router_logits)
+
+    w13_weight = jax_view(layer.w13_weight)
+    w13_weight_scale = jax_view(getattr(layer, "w13_weight_scale", None))
+    w13_bias = jax_view(getattr(layer, "w13_bias", None))
+    w2_weight = jax_view(layer.w2_weight)
+    w2_weight_scale = jax_view(getattr(layer, "w2_weight_scale", None))
+    w2_bias = jax_view(getattr(layer, "w2_bias", None))
+
+    with jax.named_scope(layer._get_name()):
+        match moe_backend:
+            case FusedMoEBackend.FUSED_MOE:
+                actual_hidden_size = x.shape[-1]
+                padding_size = w13_weight.shape[-2] - actual_hidden_size
+                x = jnp.pad(x, ((0, 0), (0, padding_size)))
+                output = fused_ep_moe(
+                    mesh=mesh,
+                    tokens=x,
+                    w1=w13_weight,
+                    w2=w2_weight,
+                    w1_scale=w13_weight_scale,
+                    w2_scale=w2_weight_scale,
+                    b1=w13_bias,
+                    b2=w2_bias,
+                    gating_output=gating_output,
+                    top_k=layer.top_k,
+                    renormalize_topk_logits=layer.renormalize,
+                    act_fn=layer.activation,
+                    **extra_backend_kwargs,
+                )[:, :actual_hidden_size]
+            case FusedMoEBackend.GMM_EP | FusedMoEBackend.GMM_TP:
+                output = fused_moe_func(
+                    hidden_states=x,
+                    w1=w13_weight,
+                    w2=w2_weight,
+                    w1_scale=w13_weight_scale,
+                    w2_scale=w2_weight_scale,
+                    w1_bias=w13_bias,
+                    w2_bias=w2_bias,
+                    gating_output=gating_output,
+                    topk=layer.top_k,
+                    renormalize=layer.renormalize,
+                    mesh=mesh,
+                    use_ep=layer.use_ep,
+                    activation=layer.activation,
+                )
+
+    return torch_view(output)
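For orientation: the hunk above strips the inlined GMM MoE helpers from this module (the GMM path is now imported from tpu_inference/layers/common/fused_moe_gmm.py, a new +506 line file in the listing) and adds a small dispatcher. select_moe_backend picks a FusedMoEBackend from the vLLM FusedMoEConfig and the USE_MOE_EP_KERNEL env flag, and fused_moe_apply then routes to either the fused_ep_moe Pallas kernel or the shared fused_moe_func GMM path. The snippet below is a minimal standalone sketch of that selection logic only; pick_backend is a hypothetical stand-in for select_moe_backend that takes plain booleans instead of a FusedMoEConfig and the env flag.

from enum import Enum

class FusedMoEBackend(Enum):
    # Mirrors the enum added in the hunk above.
    FUSED_MOE = "fused_moe"  # Pallas fused expert-parallel kernel
    GMM_EP = "gmm_ep"        # grouped-matmul path with expert parallelism
    GMM_TP = "gmm_tp"        # grouped-matmul path with tensor parallelism (default)

def pick_backend(use_moe_ep_kernel: bool, use_ep: bool) -> FusedMoEBackend:
    # Hypothetical stand-in for select_moe_backend: the fused EP kernel is only
    # chosen when both the env flag and expert parallelism are enabled;
    # otherwise fall back to one of the GMM implementations.
    if use_moe_ep_kernel and use_ep:
        return FusedMoEBackend.FUSED_MOE
    if use_ep:
        return FusedMoEBackend.GMM_EP
    return FusedMoEBackend.GMM_TP

# Decision table: (USE_MOE_EP_KERNEL, use_ep) -> backend
for flags in [(True, True), (True, False), (False, True), (False, False)]:
    print(flags, pick_backend(*flags).value)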
tpu_inference/layers/vllm/linear.py
@@ -0,0 +1,64 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import jax
+from jax.sharding import Mesh, NamedSharding
+from jax.sharding import PartitionSpec as P
+
+from tpu_inference import envs
+from tpu_inference.kernels.quantized_matmul.kernel import (
+    quantized_matmul_kernel, xla_quantized_matmul)
+
+
+def sharded_quantized_matmul(x: jax.Array, w_q: jax.Array, w_s: jax.Array,
+                             mesh: Mesh, weight_sharding: P) -> jax.Array:
+    """
+    Wrapper around the quantized matmul kernel.
+
+    Args:
+        x: Activation.
+        w_q: Weight quantized array. [n_output_features, n_input_features]
+        w_s: Weight quantization scale. [n_output_features]
+        mesh: Mesh to shard on.
+        weight_sharding: PartitionSpec for the weight tensor.
+
+    Returns:
+        Output of the quantized matmul.
+    """
+
+    # NOTE (jacobplatin/kyuyeunk) there have been numeric issues (concerning) NaNs
+    # with the kernel and thus we disable it for now.
+    if envs.ENABLE_QUANTIZED_MATMUL_KERNEL:
+        out_axis, in_axis = weight_sharding
+        x_sharding = P(None, in_axis)
+        scale_sharding = P(out_axis, )
+        out_sharding = P(None, out_axis)
+
+        x = jax.lax.with_sharding_constraint(x,
+                                             NamedSharding(mesh, x_sharding))
+
+        def wrapper(x, w_q, w_s):
+            output = quantized_matmul_kernel(x, w_q, w_s, x_q_dtype=w_q.dtype)
+            if in_axis:
+                output = jax.lax.psum(output, axis_name=in_axis)
+            return output
+
+        return jax.shard_map(wrapper,
+                             mesh=mesh,
+                             in_specs=(x_sharding, weight_sharding,
+                                       scale_sharding),
+                             out_specs=(out_sharding),
+                             check_vma=False)(x, w_q, w_s)
+    else:
+        return xla_quantized_matmul(x, w_q, w_s)
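The new module above (64 added lines, lining up with tpu_inference/layers/vllm/linear.py in the file list) shard_maps the quantized matmul kernel when ENABLE_QUANTIZED_MATMUL_KERNEL is set, adding a psum over the contracting axis when the weight's input dimension is sharded, and otherwise falls back to xla_quantized_matmul. Below is a hedged usage sketch, assuming a one-dimensional "model" mesh and illustrative shapes; the import path is inferred from this diff and not confirmed by it.

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

from tpu_inference.layers.vllm.linear import sharded_quantized_matmul  # path assumed

# One flat "model" axis over all local devices.
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

x = jnp.ones((8, 256), dtype=jnp.bfloat16)   # activations: [tokens, n_input_features]
w_q = jnp.ones((512, 256), dtype=jnp.int8)   # quantized weight: [n_output_features, n_input_features]
w_s = jnp.ones((512,), dtype=jnp.bfloat16)   # per-output-feature scale

# Column-parallel layout: output features sharded over "model", inputs replicated,
# so no cross-shard reduction is needed.
out = sharded_quantized_matmul(x, w_q, w_s, mesh, weight_sharding=P("model", None))
print(out.shape)  # (8, 512)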
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.