tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +78 -1
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +38 -7
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +17 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +28 -5
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +74 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +88 -25
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -64
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +72 -37
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +45 -15
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +41 -16
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +42 -36
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +63 -50
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
- tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
|
@@ -1,203 +1,266 @@
|
|
|
1
|
-
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Union
|
|
2
16
|
|
|
3
17
|
import jax
|
|
4
18
|
import jax.numpy as jnp
|
|
5
19
|
import torch
|
|
6
|
-
|
|
20
|
+
from compressed_tensors.quantization import QuantizationArgs
|
|
7
21
|
from jax.experimental.layout import Format, Layout
|
|
8
22
|
from jax.sharding import Mesh, NamedSharding
|
|
9
23
|
from jax.sharding import PartitionSpec as P
|
|
10
24
|
from torch.nn.parameter import Parameter
|
|
11
|
-
from torchax.interop import
|
|
25
|
+
from torchax.interop import jax_view, torch_view
|
|
12
26
|
from torchax.ops.mappings import t2j
|
|
13
27
|
from vllm.logger import init_logger
|
|
14
28
|
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEConfig
|
|
15
|
-
from vllm.model_executor.layers.quantization.compressed_tensors.
|
|
16
|
-
|
|
17
|
-
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import \
|
|
18
|
-
CompressedTensorsW8A8Fp8MoEMethod
|
|
19
|
-
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
|
|
20
|
-
WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
|
|
29
|
+
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501
|
|
30
|
+
CompressedTensorsMoEMethod, CompressedTensorsW8A8Fp8MoEMethod)
|
|
21
31
|
|
|
32
|
+
from tpu_inference.layers.vllm.fused_moe import fused_moe_func
|
|
33
|
+
from tpu_inference.layers.vllm.linear_common import \
|
|
34
|
+
reorder_concatenated_tensor_for_sharding
|
|
22
35
|
from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
|
|
36
|
+
from tpu_inference.layers.vllm.quantization.unquantized import \
|
|
37
|
+
VllmUnquantizedFusedMoEMethod
|
|
23
38
|
|
|
24
39
|
logger = init_logger(__name__)
|
|
25
40
|
|
|
26
41
|
|
|
42
|
+
class VllmCompressedTensorsMoEMethod(CompressedTensorsMoEMethod):
|
|
43
|
+
|
|
44
|
+
@staticmethod
|
|
45
|
+
def get_moe_method(
|
|
46
|
+
quant_config: "VllmCompressedTensorsConfig", # type: ignore # noqa E501
|
|
47
|
+
layer: torch.nn.Module,
|
|
48
|
+
layer_name: str,
|
|
49
|
+
) -> CompressedTensorsMoEMethod:
|
|
50
|
+
assert isinstance(layer, FusedMoE)
|
|
51
|
+
|
|
52
|
+
# FusedMoE was made by combining multiple Linears so need to
|
|
53
|
+
# make sure quantization config for Linear can target it
|
|
54
|
+
quant_config._add_fused_moe_to_target_scheme_map()
|
|
55
|
+
unfused_names = [
|
|
56
|
+
layer_name + proj_name
|
|
57
|
+
for proj_name in [".0.gate_proj", ".0.up_proj", ".0.down_proj"]
|
|
58
|
+
]
|
|
59
|
+
# TODO: refactor this to use expert_mapping and check all layer numbers
|
|
60
|
+
all_scheme_dicts = [
|
|
61
|
+
quant_config.get_scheme_dict(layer, name) for name in unfused_names
|
|
62
|
+
]
|
|
63
|
+
scheme_dict = all_scheme_dicts.pop()
|
|
64
|
+
|
|
65
|
+
# multiple schemes found
|
|
66
|
+
if not all([cur_dict == scheme_dict for cur_dict in all_scheme_dicts]):
|
|
67
|
+
raise ValueError("All MoE projections need to have same "
|
|
68
|
+
"quantization scheme but found multiple")
|
|
69
|
+
|
|
70
|
+
if scheme_dict is None:
|
|
71
|
+
return VllmUnquantizedFusedMoEMethod(layer.moe_config,
|
|
72
|
+
quant_config.mesh)
|
|
73
|
+
|
|
74
|
+
weight_quant = scheme_dict.get("weights")
|
|
75
|
+
input_quant = scheme_dict.get("input_activations")
|
|
76
|
+
|
|
77
|
+
if quant_config._is_fp8_w8a8(weight_quant, input_quant):
|
|
78
|
+
return VllmCompressedTensorsW8A8Fp8MoEMethod(
|
|
79
|
+
weight_quant, input_quant, layer.moe_config, quant_config.mesh)
|
|
80
|
+
else:
|
|
81
|
+
raise RuntimeError(
|
|
82
|
+
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
|
|
83
|
+
|
|
84
|
+
|
|
27
85
|
class VllmCompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsW8A8Fp8MoEMethod,
|
|
28
86
|
JaxCommonConfig):
|
|
29
87
|
|
|
30
|
-
def __init__(
|
|
31
|
-
|
|
32
|
-
|
|
88
|
+
def __init__(
|
|
89
|
+
self,
|
|
90
|
+
weight_quant: QuantizationArgs,
|
|
91
|
+
input_quant: QuantizationArgs,
|
|
92
|
+
moe: FusedMoEConfig,
|
|
93
|
+
mesh: Mesh,
|
|
94
|
+
):
|
|
95
|
+
super().__init__(weight_quant, input_quant, moe)
|
|
33
96
|
self.mesh = mesh
|
|
34
|
-
self.quant_config = quant_config
|
|
35
|
-
|
|
36
|
-
# disable GPU paths
|
|
37
|
-
self.use_marlin = False
|
|
38
|
-
self.rocm_aiter_moe_enabled = False # is_rocm_aiter_moe_enabled()
|
|
39
|
-
self.is_fp8_w8a8_sm100 = False
|
|
40
|
-
self.use_cutlass = False
|
|
41
|
-
self.disable_expert_map = False
|
|
42
97
|
|
|
43
98
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
|
99
|
+
"""
|
|
100
|
+
Convert the loaded FP8 MoE weights and scales to JAX arrays, shard them, and store them back on the layer.
|
|
101
|
+
|
|
102
|
+
:param self: Description
|
|
103
|
+
:param layer: FusedMoE layer whose w13/w2 weights and weight scales are replaced in place
|
|
104
|
+
:type layer: torch.nn.Module
|
|
105
|
+
|
|
106
|
+
Steps:
|
|
107
|
+
1. Read weights from layer object and convert to jax arrays
|
|
108
|
+
2. Interleave concat w13 weights
|
|
109
|
+
3. Shard weights for tp (rowwise w13, colwise w2)
|
|
110
|
+
4. Initialize Params as torch.nn.Parameter
|
|
111
|
+
a. w13_weight - float8_e4m3fn shape: (num_experts, 2 x intermediate_size, input_size)
|
|
112
|
+
b. w2_weight - float8_e4m3fn shape: (num_experts, output_size, intermediate_size)
|
|
113
|
+
c. w13_weight_scale - FP32 shape: (num_experts, 2 x intermediate_size, 1)
|
|
114
|
+
d. w2_weight_scale - FP32 shape: (num_experts, output_size, 1)
|
|
115
|
+
"""
|
|
44
116
|
assert isinstance(layer, FusedMoE)
|
|
45
117
|
|
|
118
|
+
# Read weights from layer object
|
|
119
|
+
w13_weight = t2j(
|
|
120
|
+
layer.w13_weight, use_dlpack=False
|
|
121
|
+
) # float8_e4m3fn shape: (num_experts, 2 x intermediate_size, input_size)
|
|
122
|
+
w13_weight_scale = t2j(
|
|
123
|
+
layer.w13_weight_scale, use_dlpack=False
|
|
124
|
+
) # FP32 shape: (num_experts, 2 x intermediate_size, 1)
|
|
125
|
+
w2_weight = t2j(
|
|
126
|
+
layer.w2_weight, use_dlpack=False
|
|
127
|
+
) # float8_e4m3fn shape: (num_experts, output_size, intermediate_size)
|
|
128
|
+
w2_weight_scale = t2j(layer.w2_weight_scale, use_dlpack=False)
|
|
129
|
+
w13_weight_scale = w13_weight_scale.astype(jnp.bfloat16)
|
|
130
|
+
w2_weight_scale = w2_weight_scale.astype(jnp.bfloat16)
|
|
46
131
|
intermediate_size = layer.w13_weight.shape[1] // 2
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
132
|
+
assert intermediate_size == w2_weight.shape[-1]
|
|
133
|
+
n_shards = self.mesh.shape["model"]
|
|
134
|
+
assert intermediate_size % n_shards == 0
|
|
135
|
+
num_experts, hidden_size, intermediate_size = w2_weight.shape
|
|
136
|
+
assert w2_weight_scale.shape == (num_experts, hidden_size, 1)
|
|
137
|
+
assert w13_weight.shape == (num_experts, 2 * intermediate_size,
|
|
138
|
+
hidden_size)
|
|
139
|
+
assert w13_weight_scale.shape == (num_experts, 2 * intermediate_size,
|
|
140
|
+
1)
|
|
141
|
+
|
|
142
|
+
if not layer.use_ep:
|
|
143
|
+
# Interleave concat w13 weights
|
|
144
|
+
w13_weight = reorder_concatenated_tensor_for_sharding(
|
|
145
|
+
w13_weight,
|
|
146
|
+
split_sizes=(intermediate_size, intermediate_size),
|
|
147
|
+
dim=1,
|
|
148
|
+
n_shards=n_shards,
|
|
149
|
+
)
|
|
150
|
+
# Interleave concat w13 weight scales
|
|
151
|
+
w13_weight_scale = reorder_concatenated_tensor_for_sharding(
|
|
152
|
+
w13_weight_scale,
|
|
153
|
+
split_sizes=(intermediate_size, intermediate_size),
|
|
154
|
+
dim=1,
|
|
155
|
+
n_shards=n_shards,
|
|
156
|
+
)
|
|
157
|
+
|
|
158
|
+
# 160,5120,1 -> 160,1,5120
|
|
159
|
+
w13_weight_scale = jnp.swapaxes(w13_weight_scale, 1, 2)
|
|
160
|
+
# 160,1,5120 -> 160, 1, 1, 5120 (num_experts, num_blocks, 1, outer_dim)
|
|
161
|
+
w13_weight_scale = jnp.expand_dims(w13_weight_scale, 2)
|
|
162
|
+
w2_weight_scale = jnp.swapaxes(w2_weight_scale, 1, 2)
|
|
163
|
+
w2_weight_scale = jnp.expand_dims(w2_weight_scale, 2)
|
|
61
164
|
|
|
62
165
|
if layer.use_ep:
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
w2_weight_scale = jax.device_put(w2_weight_scale, format)
|
|
71
|
-
else:
|
|
72
|
-
assert intermediate_size == w2_weight.shape[-1]
|
|
73
|
-
n_shards = self.mesh.shape["model"]
|
|
74
|
-
assert intermediate_size % n_shards == 0
|
|
166
|
+
# Apply EP sharding
|
|
167
|
+
ep_sharding = NamedSharding(self.mesh, P("model"))
|
|
168
|
+
|
|
169
|
+
w13_weight = jax.lax.with_sharding_constraint(
|
|
170
|
+
w13_weight, ep_sharding)
|
|
171
|
+
w2_weight = jax.lax.with_sharding_constraint(
|
|
172
|
+
w2_weight, ep_sharding)
|
|
75
173
|
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
# )
|
|
174
|
+
w13_weight_scale = jax.lax.with_sharding_constraint(
|
|
175
|
+
w13_weight_scale, ep_sharding)
|
|
176
|
+
w2_weight_scale = jax.lax.with_sharding_constraint(
|
|
177
|
+
w2_weight_scale, ep_sharding)
|
|
81
178
|
|
|
179
|
+
else:
|
|
180
|
+
# Shard weights for tp (rowwise w13, colwise w2)
|
|
82
181
|
w13_format = Format(
|
|
83
|
-
Layout((0, 1, 2)),
|
|
84
|
-
NamedSharding(self.mesh, P(None, "model", None))
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
182
|
+
Layout((0, 1, 2)), # expert, 2xintermed, input
|
|
183
|
+
NamedSharding(self.mesh, P(None, "model", None)),
|
|
184
|
+
) # rowwise sharding on intermed dim
|
|
185
|
+
|
|
186
|
+
w13_scale_format = Format(
|
|
187
|
+
Layout(
|
|
188
|
+
(0, 1, 2, 3)), # (num_experts, num_blocks, 1, outer_dim)
|
|
189
|
+
NamedSharding(self.mesh, P(None, None, None, "model")),
|
|
190
|
+
) # col wise GMM sharding on intermed dim
|
|
191
|
+
|
|
192
|
+
# Local shard shape: (num_experts, 2 x (intermediate_size // n_shards), input_size)
|
|
193
|
+
w13_weight = jax.lax.with_sharding_constraint(
|
|
194
|
+
w13_weight, w13_format)
|
|
195
|
+
# Local shard shape: (num_experts, 1, 1, 2 x (intermediate_size // n_shards))
|
|
196
|
+
w13_weight_scale = jax.lax.with_sharding_constraint(
|
|
197
|
+
w13_weight_scale, w13_scale_format)
|
|
198
|
+
|
|
199
|
+
# Shard weights for tp (colwise w2)
|
|
200
|
+
w2_format = Format(
|
|
201
|
+
Layout((0, 1, 2)), # expert, intermed, hidden
|
|
202
|
+
NamedSharding(self.mesh, P(None, None, "model")),
|
|
93
203
|
)
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
) # replicate
|
|
204
|
+
# Local shard shape: (num_experts, hidden, (intermediate_size // n_shards))
|
|
205
|
+
# # (num_experts, num_blocks, 1, outer_dim)
|
|
206
|
+
w2_weight = jax.lax.with_sharding_constraint(w2_weight, w2_format)
|
|
98
207
|
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
208
|
+
w2_scale_format = Format(
|
|
209
|
+
Layout((0, 1, 2, 3)), # expert, intermed, 1
|
|
210
|
+
NamedSharding(self.mesh, P(None, None, None, None)),
|
|
211
|
+
)
|
|
212
|
+
# Replicated on every device; shape: (num_experts, 1, 1, hidden_size)
|
|
213
|
+
w2_weight_scale = jax.lax.with_sharding_constraint(
|
|
214
|
+
w2_weight_scale, w2_scale_format)
|
|
215
|
+
|
|
216
|
+
w13_weight = Parameter(torch_view(w13_weight), requires_grad=False)
|
|
217
|
+
w13_weight_scale = Parameter(torch_view(w13_weight_scale),
|
|
218
|
+
requires_grad=False)
|
|
102
219
|
w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
|
|
103
220
|
w2_weight_scale = Parameter(torch_view(w2_weight_scale),
|
|
104
221
|
requires_grad=False)
|
|
105
|
-
w3_weight = Parameter(torch_view(w3_weight), requires_grad=False)
|
|
106
|
-
w3_weight_scale = Parameter(torch_view(w3_weight_scale),
|
|
107
|
-
requires_grad=False)
|
|
108
222
|
|
|
109
|
-
|
|
110
|
-
layer.
|
|
111
|
-
layer.w13_weight_scale = w1_weight_scale
|
|
223
|
+
layer.w13_weight = w13_weight
|
|
224
|
+
layer.w13_weight_scale = w13_weight_scale
|
|
112
225
|
layer.w2_weight = w2_weight
|
|
113
226
|
layer.w2_weight_scale = w2_weight_scale
|
|
114
|
-
layer.w3_weight = w3_weight
|
|
115
|
-
layer.w3_weight_scale = w3_weight_scale
|
|
116
227
|
|
|
117
228
|
def apply(
|
|
118
229
|
self,
|
|
119
230
|
layer: torch.nn.Module,
|
|
120
231
|
x: torch.Tensor,
|
|
121
232
|
router_logits: torch.Tensor,
|
|
122
|
-
top_k: int,
|
|
123
|
-
renormalize: bool,
|
|
124
|
-
use_grouped_topk: bool = False,
|
|
125
|
-
topk_group: Optional[int] = None,
|
|
126
|
-
num_expert_group: Optional[int] = None,
|
|
127
|
-
global_num_experts: int = -1,
|
|
128
|
-
expert_map: Optional[torch.Tensor] = None,
|
|
129
|
-
custom_routing_function: Optional[Callable] = None,
|
|
130
|
-
scoring_func: str = "softmax",
|
|
131
|
-
routed_scaling_factor: float = 1.0,
|
|
132
|
-
e_score_correction_bias: Optional[torch.Tensor] = None,
|
|
133
|
-
apply_router_weight_on_input: bool = False,
|
|
134
|
-
activation: str = "silu",
|
|
135
|
-
enable_eplb: bool = False,
|
|
136
|
-
expert_load_view: Optional[torch.Tensor] = None,
|
|
137
|
-
logical_to_physical_map: Optional[torch.Tensor] = None,
|
|
138
|
-
logical_replica_count: Optional[torch.Tensor] = None,
|
|
139
233
|
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
|
140
234
|
assert isinstance(layer, FusedMoE)
|
|
141
|
-
if activation != "silu":
|
|
235
|
+
if layer.activation != "silu":
|
|
142
236
|
raise NotImplementedError(
|
|
143
237
|
"Only silu is supported for activation function.")
|
|
144
|
-
if scoring_func != "softmax":
|
|
238
|
+
if layer.scoring_func != "softmax":
|
|
145
239
|
raise NotImplementedError(
|
|
146
240
|
"Only softmax is supported for scoring_func")
|
|
147
241
|
|
|
148
|
-
# import sys
|
|
149
|
-
# sys.stdin = open(0)
|
|
150
|
-
# breakpoint()
|
|
151
|
-
|
|
152
242
|
# TODO: Use MoE kernel when it supports fp8
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
#x3 = torch.einsum("ti, eoi -> teo", x, layer.w3_weight) * self.w3_weight_scale
|
|
177
|
-
x3 = call_jax(jax.lax.dot,
|
|
178
|
-
x,
|
|
179
|
-
layer.w3_weight,
|
|
180
|
-
dimension_numbers=(((1, ), (2, )), ((), ())),
|
|
181
|
-
preferred_element_type=jnp.bfloat16.dtype
|
|
182
|
-
) * layer.w3_weight_scale.squeeze(2)
|
|
183
|
-
|
|
184
|
-
#expert_outs = torch.einsum("teo, eio -> tei", (x1 * x3), self.w2_weight) * self.w2_weight_scale
|
|
185
|
-
expert_outs = call_jax(
|
|
186
|
-
jax.lax.dot,
|
|
187
|
-
x1 * x3,
|
|
188
|
-
layer.w2_weight,
|
|
189
|
-
dimension_numbers=(((2, ), (2, )), ((1, ), (0, ))),
|
|
190
|
-
preferred_element_type=jnp.bfloat16.dtype).transpose(
|
|
191
|
-
0, 1) * layer.w2_weight_scale.squeeze(2)
|
|
192
|
-
|
|
193
|
-
seq_indexes = torch.arange(seqlen, device='jax').unsqueeze(1)
|
|
194
|
-
expert_outs = expert_outs[seq_indexes, expert_indices]
|
|
195
|
-
|
|
196
|
-
# out = torch.einsum("tai,ta -> ti", expert_outs, expert_weights)
|
|
197
|
-
out = call_jax(jax.lax.dot,
|
|
198
|
-
expert_outs,
|
|
199
|
-
expert_weights,
|
|
200
|
-
dimension_numbers=(((1, ), (1, )), ((0, ), (0, ))),
|
|
201
|
-
preferred_element_type=jnp.bfloat16.dtype)
|
|
243
|
+
x = jax_view(x)
|
|
244
|
+
w13_weight = jax_view(layer.w13_weight)
|
|
245
|
+
w2_weight = jax_view(layer.w2_weight)
|
|
246
|
+
w13_weight_scale = jax_view(layer.w13_weight_scale)
|
|
247
|
+
w2_weight_scale = jax_view(layer.w2_weight_scale)
|
|
248
|
+
gating_output = jax_view(router_logits)
|
|
249
|
+
out = torch_view(
|
|
250
|
+
fused_moe_func(
|
|
251
|
+
hidden_states=x,
|
|
252
|
+
w1=w13_weight,
|
|
253
|
+
w2=w2_weight,
|
|
254
|
+
w1_scale=w13_weight_scale,
|
|
255
|
+
w2_scale=w2_weight_scale,
|
|
256
|
+
w1_bias=None,
|
|
257
|
+
w2_bias=None,
|
|
258
|
+
gating_output=gating_output,
|
|
259
|
+
topk=layer.top_k,
|
|
260
|
+
renormalize=layer.renormalize,
|
|
261
|
+
mesh=self.mesh,
|
|
262
|
+
use_ep=layer.use_ep,
|
|
263
|
+
activation=layer.activation,
|
|
264
|
+
))
|
|
202
265
|
|
|
203
266
|
return out
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
CHANGED
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
from typing import Optional
|
|
2
16
|
|
|
3
17
|
import jax
|
tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
CHANGED
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
from typing import Optional
|
|
2
16
|
|
|
3
17
|
import jax
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Optional, Union
|
|
16
|
+
|
|
17
|
+
import jax
|
|
18
|
+
import torch
|
|
19
|
+
from jax.sharding import PartitionSpec
|
|
20
|
+
from vllm.logger import init_logger
|
|
21
|
+
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
|
22
|
+
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
|
23
|
+
from vllm.model_executor.layers.quantization import \
|
|
24
|
+
register_quantization_config
|
|
25
|
+
from vllm.model_executor.layers.quantization.base_config import \
|
|
26
|
+
QuantizeMethodBase
|
|
27
|
+
from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
|
|
28
|
+
Fp8LinearMethod)
|
|
29
|
+
from vllm.model_executor.layers.quantization.utils.quant_utils import \
|
|
30
|
+
is_layer_skipped
|
|
31
|
+
|
|
32
|
+
from tpu_inference.layers.common.quant_methods import FP8, get_tpu_quant_method
|
|
33
|
+
from tpu_inference.layers.vllm.quantization.common import (
|
|
34
|
+
JaxCommonConfig, JaxCommonLinearConfig)
|
|
35
|
+
from tpu_inference.layers.vllm.quantization.unquantized import \
|
|
36
|
+
VllmUnquantizedLinearMethod
|
|
37
|
+
|
|
38
|
+
P = PartitionSpec
|
|
39
|
+
logger = init_logger(__name__)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@register_quantization_config(get_tpu_quant_method(FP8))
class VllmFp8Config(Fp8Config, JaxCommonConfig):
    """FP8 quantization config for the torchax/JAX execution path.

    Registered under the name produced by ``get_tpu_quant_method(FP8)``.
    Combines vLLM's ``Fp8Config`` with the ``JaxCommonConfig`` mixin, which
    is the source of ``get_linear_config`` used below.
    """

    @classmethod
    def get_name(cls):
        """Return the ``FP8`` quantization-method identifier."""
        return FP8

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        """Return the activation dtypes this config accepts (bfloat16 only)."""
        supported_dtypes = [torch.bfloat16]
        return supported_dtypes

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[Union["LinearMethodBase", "QuantizeMethodBase"]]:
        """Select the quantization method object for ``layer``.

        Args:
            layer: The module being configured.
            prefix: The layer's name prefix, checked against
                ``self.ignored_layers``.

        Returns:
            ``VllmUnquantizedLinearMethod`` for a linear layer that is
            skipped, ``VllmFp8LinearMethod`` for any other linear layer, and
            ``None`` for layers that are neither linear nor ``FusedMoE``.

        Raises:
            NotImplementedError: If ``layer`` is a ``FusedMoE`` (and not a
                ``LinearBase``).
        """
        # Handle everything that is not a linear layer first.
        if not isinstance(layer, LinearBase):
            if isinstance(layer, FusedMoE):
                raise NotImplementedError(
                    "FP8 FusedMoE is currently not supported in torchax-jax")
            return None

        # The linear config is built before the skip check, as both result
        # types below need it.
        jax_linear_config = self.get_linear_config(layer)
        layer_is_skipped = is_layer_skipped(prefix, self.ignored_layers)
        if layer_is_skipped:
            return VllmUnquantizedLinearMethod(jax_linear_config)
        return VllmFp8LinearMethod(self, jax_linear_config)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class VllmFp8LinearMethod(Fp8LinearMethod):
    """Skeleton FP8 linear method for the torchax/JAX path.

    Only ``apply`` contains real dispatch logic. The sharding setup, weight
    post-processing and both matmul paths raise ``NotImplementedError``.
    Because ``__init__`` calls ``_configure_sharding``, constructing this
    class as written raises as well.
    """

    def __init__(self, quant_config: VllmFp8Config,
                 jax_config: JaxCommonLinearConfig):
        """Store the JAX linear config and trigger sharding configuration.

        Args:
            quant_config: The FP8 config forwarded to the parent initializer.
            jax_config: JAX-side linear settings; ``apply`` reads its
                ``fuse_matmuls`` attribute.
        """
        super().__init__(quant_config)
        self.jax_config = jax_config
        self._configure_sharding()

    def _configure_sharding(self) -> None:
        """Placeholder for sharding setup; always raises."""
        message = (
            "Configure PartitionSpec for weight_sharding and scale_sharding "
            "based on layer type (RowParallel/ColumnParallel)")
        raise NotImplementedError(message)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Placeholder for torch-to-JAX weight conversion; always raises."""
        message = (
            "Convert layer.weight, layer.weight_scale, and optionally "
            "layer.input_scale and layer.bias from torch tensors to JAX arrays "
            "using torch_to_jax_param() with appropriate sharding")
        raise NotImplementedError(message)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the linear layer under a JAX named scope.

        Dispatches to ``_apply_fused`` when ``self.jax_config.fuse_matmuls``
        is truthy and to ``_apply_split`` otherwise, returning that result.
        """
        with jax.named_scope(layer._get_name()):
            matmul_impl = (self._apply_fused
                           if self.jax_config.fuse_matmuls else
                           self._apply_split)
            return matmul_impl(layer, x, bias)

    def _apply_fused(self,
                     layer: torch.nn.Module,
                     x: torch.Tensor,
                     bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Placeholder for the fused single-matmul path; always raises."""
        message = (
            "Implement single matmul for fused outputs: "
            "quantize input to fp8, perform fp8 matmul with weight and scales, "
            "dequantize output, and add bias if present")
        raise NotImplementedError(message)

    def _apply_split(self,
                     layer: torch.nn.Module,
                     x: torch.Tensor,
                     bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Placeholder for the per-partition matmul path; always raises."""
        message = (
            "Implement separate matmuls per output partition: "
            "split weight/scale by output_sizes, perform fp8 matmul for each, "
            "concatenate results, and add bias if present")
        raise NotImplementedError(message)
|