tpu-inference 0.11.1.dev202511150811__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_dp_scheduler.py +899 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/fused_moe_v1_test.py +105 -0
- tests/kernels/mla_v1_test.py +396 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/conftest.py +32 -0
- tests/lora/test_bgmv.py +43 -0
- tests/lora/test_layers.py +654 -0
- tests/lora/test_lora.py +133 -0
- tests/lora/utils.py +96 -0
- tests/test_base.py +201 -0
- tests/test_envs.py +182 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +236 -0
- tpu_inference/__init__.py +34 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/core/sched/__init__.py +0 -0
- tpu_inference/core/sched/dp_scheduler.py +523 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/jax_parallel_state.py +67 -0
- tpu_inference/distributed/tpu_connector.py +728 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +107 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +362 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
- tpu_inference/kernels/mla/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/kernel.py +1349 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_interface.py +390 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/common/sharding.py +582 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +255 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +280 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +96 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
- tpu_inference/layers/jax/transformer_block.py +107 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +507 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +39 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
- tpu_inference/layers/vllm/sharding.py +230 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +311 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +444 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/gpt_oss.py +492 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
- tpu_inference/models/jax/llama3.py +375 -0
- tpu_inference/models/jax/llama4.py +629 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
- tpu_inference/models/jax/utils/weight_utils.py +529 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_platform.py +269 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +780 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +132 -0
- tpu_inference/runner/kv_cache_manager.py +479 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +217 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +248 -0
- tpu_inference/runner/structured_decoding_manager.py +88 -0
- tpu_inference/runner/tpu_runner.py +1620 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +367 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +317 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/tpu_worker.py +321 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tpu_inference/layers/vllm/quantization/__init__.py
@@ -0,0 +1,39 @@
+import copy
+
+from jax.sharding import Mesh
+from vllm.config import VllmConfig
+from vllm.model_executor.layers.quantization.base_config import \
+    QuantizationConfig
+
+from tpu_inference.layers.common import quant_methods
+from tpu_inference.layers.vllm.quantization.awq import VllmAWQConfig
+from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
+from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import \
+    VllmCompressedTensorsConfig  # noqa: E501
+from tpu_inference.layers.vllm.quantization.mxfp4 import VllmMxfp4Config
+from tpu_inference.layers.vllm.quantization.unquantized import \
+    VllmUnquantizedConfig
+
+
+def get_tpu_quantization_config(vllm_config: VllmConfig,
+                                mesh: Mesh) -> QuantizationConfig:
+    model_config = copy.deepcopy(vllm_config.model_config)
+    # TODO(kyuyeunk): Add support for "tpu_int8".
+    method_to_config: dict[str, str] = {
+        None: VllmUnquantizedConfig,
+        quant_methods.COMPRESSED_TENSORS: VllmCompressedTensorsConfig,
+        quant_methods.AWQ: VllmAWQConfig,
+        quant_methods.MXFP4: VllmMxfp4Config,
+    }
+    if model_config.quantization not in method_to_config:
+        raise NotImplementedError(
+            f"{model_config.quantization} quantization method not supported."
+            f" Supported methods are {method_to_config.keys()}")
+    quant_config = method_to_config[model_config.quantization]
+    assert issubclass(quant_config, JaxCommonConfig)
+    quant_config.set_configs(vllm_config, mesh)
+
+    model_config.quantization = quant_methods.get_tpu_quant_method(
+        quant_config.get_name())
+    return VllmConfig.get_quantization_config(model_config,
+                                              vllm_config.load_config)
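The function above is a plain name-to-class dispatch keyed on the checkpoint's quantization method name. A minimal, self-contained sketch of the same pattern (class names below are placeholders, not the package's API):

class UnquantizedCfg: ...
class AwqCfg: ...

METHOD_TO_CONFIG = {None: UnquantizedCfg, "awq": AwqCfg}

def resolve(method):
    # Unknown method names fail fast, mirroring the NotImplementedError above.
    if method not in METHOD_TO_CONFIG:
        raise NotImplementedError(
            f"{method} quantization method not supported. "
            f"Supported methods are {list(METHOD_TO_CONFIG)}")
    return METHOD_TO_CONFIG[method]

assert resolve("awq") is AwqCfg
assert resolve(None) is UnquantizedCfg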
tpu_inference/layers/vllm/quantization/awq.py
@@ -0,0 +1,207 @@
+from typing import Optional, Union
+
+import jax
+import jax.numpy as jnp
+import torch
+from jax.sharding import NamedSharding, PartitionSpec
+from torchax.interop import jax_view, torch_view
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization import \
+    register_quantization_config
+from vllm.model_executor.layers.quantization.awq import (AWQConfig,
+                                                         AWQLinearMethod)
+from vllm.model_executor.layers.quantization.base_config import \
+    QuantizeMethodBase
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    is_layer_skipped, unpack_quantized_values_into_int32)
+from vllm.scalar_type import scalar_types
+
+from tpu_inference.layers.common.quant_methods import AWQ, get_tpu_quant_method
+from tpu_inference.layers.vllm.linear_common import (
+    slice_sharded_tensor_for_concatenation, torch_to_jax_param)
+from tpu_inference.layers.vllm.quantization.common import (
+    JaxCommonConfig, JaxCommonLinearConfig)
+from tpu_inference.layers.vllm.quantization.unquantized import \
+    VllmUnquantizedLinearMethod
+
+P = PartitionSpec
+logger = init_logger(__name__)
+
+
+@register_quantization_config(get_tpu_quant_method(AWQ))
+class VllmAWQConfig(AWQConfig, JaxCommonConfig):
+
+    @classmethod
+    def get_name(cls):
+        return AWQ
+
+    def get_supported_act_dtypes(self) -> list[torch.dtype]:
+        # NOTE: AWQ checkpoint was quantized with float16. But on TPUs, using
+        # bfloat16 is significantly preferred over float16. This might lead to
+        # some numeric output change.
+        return [torch.bfloat16]
+
+    def get_quant_method(
+            self, layer: torch.nn.Module, prefix: str
+    ) -> Optional[Union["LinearMethodBase", "QuantizeMethodBase"]]:
+        if isinstance(layer, LinearBase):
+            linear_config = self.get_linear_config(layer)
+            if is_layer_skipped(prefix, self.modules_to_not_convert):
+                return VllmUnquantizedLinearMethod(linear_config)
+            return VllmAWQLinearMethod(self, linear_config)
+        elif isinstance(layer, FusedMoE):
+            raise NotImplementedError(
+                "AWQ FusedMoE is currently not supported in torchax-jax")
+        return None
+
+
+class VllmAWQLinearMethod(AWQLinearMethod):
+
+    def __init__(self, quant_config: VllmAWQConfig,
+                 jax_config: JaxCommonLinearConfig):
+        super().__init__(quant_config)
+        self.jax_config = jax_config
+
+        out_sharding, in_sharding = self.jax_config.weight_sharding[:]
+        self.jax_config.weight_sharding = P(in_sharding, None, out_sharding)
+        self.jax_config.scale_sharding = P(in_sharding, out_sharding)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        qweight = layer.qweight
+        qweight = unpack_awq_weight(qweight, qweight.packed_dim)
+
+        group_size = self.quant_config.group_size
+        # Reshape so that each qweight[i] was quantized with the same scales[i].
+        qweight = qweight.reshape((-1, group_size, layer.output_size))
+        qweight = torch_to_jax_param(qweight,
+                                     NamedSharding(
+                                         self.jax_config.mesh,
+                                         self.jax_config.weight_sharding),
+                                     self.jax_config.output_sizes,
+                                     self.jax_config.n_shards,
+                                     self.jax_config.fuse_matmuls,
+                                     dim=2,
+                                     jax_dtype=jnp.uint4)
+        delattr(layer, "qweight")
+        layer.qweight = qweight
+
+        qzeros = layer.qzeros
+        qzeros = unpack_awq_weight(qzeros, qzeros.packed_dim)
+        qzeros = torch_to_jax_param(qzeros,
+                                    NamedSharding(
+                                        self.jax_config.mesh,
+                                        self.jax_config.scale_sharding),
+                                    self.jax_config.output_sizes,
+                                    self.jax_config.n_shards,
+                                    self.jax_config.fuse_matmuls,
+                                    dim=1,
+                                    jax_dtype=jnp.uint4)
+        delattr(layer, "qzeros")
+        layer.qzeros = qzeros
+
+        scales = torch_to_jax_param(layer.scales,
+                                    NamedSharding(
+                                        self.jax_config.mesh,
+                                        self.jax_config.scale_sharding),
+                                    self.jax_config.output_sizes,
+                                    self.jax_config.n_shards,
+                                    self.jax_config.fuse_matmuls,
+                                    dim=1)
+        delattr(layer, "scales")
+        layer.scales = scales
+
+        if layer.bias is not None and not layer.skip_bias_add:
+            if layer.return_bias:
+                logger.warning_once("Bias might return incorrect value.")
+
+            bias = torch_to_jax_param(
+                layer.bias,
+                NamedSharding(self.jax_config.mesh,
+                              self.jax_config.bias_sharding),
+                self.jax_config.output_sizes,
+                self.jax_config.n_shards,
+                self.jax_config.fuse_matmuls,
+            )
+            delattr(layer, "bias")
+            layer.bias = bias
+
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+        with jax.named_scope(layer._get_name()):
+            if self.jax_config.fuse_matmuls:
+                out = self._apply_fused(layer, x, bias)
+            else:
+                out = self._apply_split(layer, x, bias)
+
+        return out
+
+    def _apply_fused(self,
+                     layer: torch.nn.Module,
+                     x: torch.Tensor,
+                     bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        x_jax = jax_view(x)
+
+        qweight = jax_view(layer.qweight)
+        qzeros = jnp.expand_dims(jax_view(layer.qzeros), 1)
+        scales = jnp.expand_dims(jax_view(layer.scales), 1)
+
+        qweight = qweight.astype(jnp.int8)
+        qzeros = qzeros.astype(jnp.int8)
+
+        weight = (qweight - qzeros) * scales
+        weight = weight.reshape((-1, weight.shape[-1]))
+        outs = jnp.einsum("bd,df->bf", x_jax, weight)
+
+        if bias is not None and not layer.skip_bias_add:
+            outs += bias.jax()
+
+        outs = slice_sharded_tensor_for_concatenation(
+            outs, self.jax_config.output_sizes, self.jax_config.n_shards)
+        out = jnp.concatenate(outs, axis=-1)
+        return torch_view(out)
+
+    def _apply_split(self,
+                     layer: torch.nn.Module,
+                     x: torch.Tensor,
+                     bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        assert isinstance(layer.qweight, torch.nn.ParameterList)
+
+        x_jax = jax_view(x)
+        params = zip(layer.qweight, layer.qzeros, layer.scales)
+        outs = []
+        for i, (qweight, qzeros, scales) in enumerate(params):
+            qweight = jax_view(qweight)
+            scales = jnp.expand_dims(jax_view(scales), 1)
+            qzeros = jnp.expand_dims(jax_view(qzeros), 1)
+
+            qweight = qweight.astype(jnp.int8)
+            qzeros = qzeros.astype(jnp.int8)
+
+            weight = (qweight - qzeros) * scales
+            weight = weight.reshape((-1, weight.shape[-1]))
+            out = jnp.einsum("bd,df->bf", x_jax, weight)
+
+            if bias is not None and not layer.skip_bias_add:
+                out += jax_view(bias[i])
+
+            outs.append(out)
+        out = jnp.concatenate(outs, axis=-1)
+        return torch_view(out)
+
+
+def unpack_awq_weight(weight: torch.Tensor, packed_dim: int):
+    weight = unpack_quantized_values_into_int32(weight, scalar_types.uint4,
+                                                packed_dim)
+
+    # AWQ packs 8 uint4 values into 32 bits in this order: (0, 2, 4, 6, 1, 3, 5, 7).
+    # The following tuple maps the order used by AWQ back into ascending order.
+    reverse_awq_order = (0, 4, 1, 5, 2, 6, 3, 7)
+
+    orig_shape = weight.shape
+    weight = weight.reshape(orig_shape[:-1] + (-1, 8))
+    return weight[..., reverse_awq_order].reshape(orig_shape)
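The reordering at the end of unpack_awq_weight is easiest to see on a toy array. A small standalone check (illustration only, independent of the vLLM helpers used above): AWQ writes logical nibbles 0..7 into each packed 32-bit word in the order (0, 2, 4, 6, 1, 3, 5, 7), so indexing the unpacked values with (0, 4, 1, 5, 2, 6, 3, 7) restores ascending order.

import numpy as np

logical = np.arange(8)                  # nibble values 0..7 before packing
awq_order = [0, 2, 4, 6, 1, 3, 5, 7]    # physical slot j holds logical[awq_order[j]]
physical = logical[awq_order]           # what unpacking yields, slot by slot

reverse_awq_order = [0, 4, 1, 5, 2, 6, 3, 7]
restored = physical[reverse_awq_order]  # pick slots back into logical order
assert (restored == logical).all()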
tpu_inference/layers/vllm/quantization/common.py
@@ -0,0 +1,105 @@
+import torchax
+from jax.sharding import Mesh, PartitionSpec
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEConfig
+# yapf: disable
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               LinearBase,
+                                               MergedColumnParallelLinear,
+                                               QKVParallelLinear,
+                                               ReplicatedLinear,
+                                               RowParallelLinear)
+
+from tpu_inference.layers.vllm.linear_common import \
+    get_model_matmul_fusion_assignment
+from tpu_inference.utils import TPU_SECOND_LAST_MINOR
+
+# yapf: enable
+
+P = PartitionSpec
+
+logger = init_logger(__name__)
+
+
+class JaxCommonLinearConfig:
+
+    def __init__(self, vllm_config: VllmConfig, mesh: Mesh, layer: LinearBase):
+        assert isinstance(layer, LinearBase)
+
+        self.mesh = mesh
+        self.output_sizes = [layer.output_size]
+        self.weight_sharding = P(None, None)
+        self.fuse_matmuls = True
+        self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism
+        self.input_sharding = None
+        self.output_sharding = None
+
+        if isinstance(layer, RowParallelLinear):
+            self.weight_sharding = P(None, "model")
+            if self.enable_sequence_parallelism:
+                self.output_sharding = P("model", None)
+        elif isinstance(layer, ColumnParallelLinear):
+            self.weight_sharding = P("model", None)
+            if self.enable_sequence_parallelism:
+                self.input_sharding = P("model", None)
+
+            if isinstance(layer, MergedColumnParallelLinear) or isinstance(
+                    layer, QKVParallelLinear):
+                self.output_sizes = layer.output_sizes
+
+                self.fuse_matmuls = get_model_matmul_fusion_assignment(
+                    vllm_config.model_config.model,
+                    vllm_config.scheduler_config.max_num_batched_tokens,
+                    vllm_config.parallel_config.tensor_parallel_size,
+                    layer._get_name())
+        elif isinstance(layer, ReplicatedLinear):
+            self.weight_sharding = P(None, None)
+        else:
+            logger.warning(
+                "Unsupported linear layer type of %s. Can potentially yield "
+                " bad performance.", type(layer))
+
+        self.bias_sharding = P(self.weight_sharding[0])
+        self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)
+
+    def get_input_sharding(self, x: torchax.tensor.Tensor):
+        if self.enable_sequence_parallelism:
+            token_num = x.shape[0]
+            # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
+            if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
+                return self.input_sharding
+            else:
+                return None
+        return self.input_sharding
+
+    def get_output_sharding(self, x: torchax.tensor.Tensor):
+        if self.enable_sequence_parallelism:
+            token_num = x.shape[0]
+            # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
+            if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
+                return self.output_sharding
+            else:
+                return None
+        return self.output_sharding
+
+
+class JaxCommonConfig:
+    vllm_config: VllmConfig
+    mesh: Mesh
+
+    @classmethod
+    def set_configs(cls, vllm_config: VllmConfig, mesh: Mesh):
+        cls.vllm_config = vllm_config
+        cls.mesh = mesh
+
+    def get_linear_config(self, layer: LinearBase) -> JaxCommonLinearConfig:
+        assert isinstance(layer, LinearBase)
+        return JaxCommonLinearConfig(self.vllm_config, self.mesh, layer)
+
+    def get_moe_config(self, layer: FusedMoE) -> FusedMoEConfig:
+        assert isinstance(layer, FusedMoE)
+        moe_config = layer.moe_config
+        use_ep = self.vllm_config.parallel_config.enable_expert_parallel
+        moe_config.moe_parallel_config.use_ep = use_ep
+        return moe_config
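The NOTE(chengjiyao) guard in get_input_sharding/get_output_sharding only applies the sequence-parallel sharding when every "model"-axis shard still receives a reasonably sized slice of tokens. A hedged, standalone illustration of that arithmetic (TPU_SECOND_LAST_MINOR is assumed to be 8 here; the real constant lives in tpu_inference/utils.py):

def should_apply_sp_sharding(token_num: int, model_axis_size: int,
                             second_last_minor: int = 8) -> bool:
    # Mirrors the guard above: shard only if every core keeps at least
    # `second_last_minor` tokens; otherwise fall back to replicated input.
    return token_num // model_axis_size >= second_last_minor

assert should_apply_sp_sharding(64, 4)       # 16 tokens per shard -> shard
assert not should_apply_sp_sharding(16, 4)   # 4 tokens per shard -> replicate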
tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py: File without changes
tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py
@@ -0,0 +1,120 @@
+from typing import Optional
+
+import torch
+from jax.sharding import PartitionSpec
+from vllm.attention.layer import Attention
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+from vllm.model_executor.layers.linear import LinearBase
+from vllm.model_executor.layers.quantization import \
+    register_quantization_config
+from vllm.model_executor.layers.quantization.base_config import \
+    QuantizeMethodBase  # noqa: E501
+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
+    CompressedTensorsConfig, CompressedTensorsKVCacheMethod,
+    CompressedTensorsLinearMethod, CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+    find_matched_target, should_ignore_layer)
+
+from tpu_inference.layers.common.quant_methods import (COMPRESSED_TENSORS,
+                                                        get_tpu_quant_method)
+from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
+from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
+    VllmCompressedTensorsW8A8Fp8MoEMethod
+from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
+    VllmCompressedTensorsW8A8Fp8
+from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
+    VllmCompressedTensorsW8A8Int8
+from tpu_inference.layers.vllm.quantization.unquantized import \
+    VllmUnquantizedConfig
+
+P = PartitionSpec
+logger = init_logger(__name__)
+
+
+@register_quantization_config(get_tpu_quant_method(COMPRESSED_TENSORS))
+class VllmCompressedTensorsConfig(CompressedTensorsConfig, JaxCommonConfig):
+
+    @classmethod
+    def get_name(cls) -> str:
+        return COMPRESSED_TENSORS
+
+    def get_scheme(self,
+                   layer: torch.nn.Module,
+                   layer_name: Optional[str] = None
+                   ) -> Optional["CompressedTensorsScheme"]:
+        """
+        compressed-tensors supports non uniform in the following way:
+
+        targets of config_groups: There can be N config_groups which each
+        have a quantization scheme. Each config_group has a list of targets
+        which can be a full layer_name, a regex for a layer_name, or
+        an nn.Module name.
+
+        Detect whether a layer_name is found in any target and
+        use the quantization scheme corresponding to the matched target
+        to select the CompressedTensorsScheme used for inference.
+        """
+
+        # Will be empty for models with only sparsity
+        weight_quant = input_quant = None
+        if self.target_scheme_map:
+            matched_target = find_matched_target(
+                layer_name=layer_name,
+                module=layer,
+                targets=self.target_scheme_map.keys(),
+                fused_mapping=self.packed_modules_mapping,
+            )
+
+            scheme_dict = self.target_scheme_map[matched_target]
+            weight_quant = scheme_dict.get("weights")
+            input_quant = scheme_dict.get("input_activations")
+
+        if weight_quant is None:
+            logger.warning_once("Acceleration for non-quantized schemes is "
+                                "not supported by Compressed Tensors. "
+                                "Falling back to UnquantizedLinearMethod")
+            return None
+
+        # TODO(kyuyeunk): Add support for different act_quant_format
+
+        linear_config = self.get_linear_config(layer)
+        if self._is_fp8_w8a8(weight_quant, input_quant):
+            is_static_input_scheme = input_quant and not input_quant.dynamic
+            return VllmCompressedTensorsW8A8Fp8(
+                weight_quant=weight_quant,
+                is_static_input_scheme=is_static_input_scheme,
+                jax_config=linear_config,
+            )
+        if self._is_dynamic_token_w8a8(weight_quant, input_quant):
+            return VllmCompressedTensorsW8A8Int8(
+                strategy=weight_quant.strategy,
+                is_static_input_scheme=False,
+                input_symmetric=input_quant.symmetric,
+                jax_config=linear_config,
+            )
+        raise NotImplementedError(
+            "No compressed-tensors compatible scheme was found.")
+
+    def get_quant_method(
+        self,
+        layer: torch.nn.Module,
+        prefix: str,
+    ) -> Optional[QuantizeMethodBase]:
+        if should_ignore_layer(prefix,
+                               ignore=self.ignore,
+                               fused_mapping=self.packed_modules_mapping):
+            return VllmUnquantizedConfig.get_quant_method(self, layer, prefix)
+        if isinstance(layer, LinearBase):
+            scheme = self.get_scheme(layer=layer, layer_name=prefix)
+            if scheme is None:
+                return VllmUnquantizedConfig.get_quant_method(
+                    self, layer, prefix)
+            layer.scheme = scheme
+            return CompressedTensorsLinearMethod(self)
+        if isinstance(layer, FusedMoE):
+            return VllmCompressedTensorsW8A8Fp8MoEMethod(
+                self, layer.quant_config, self.mesh)
+        if isinstance(layer, Attention):
+            return CompressedTensorsKVCacheMethod(self)
+        return None
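The get_scheme docstring above describes how compressed-tensors maps layers to quantization schemes through config_group targets. A small, hypothetical target_scheme_map showing the shape of that lookup (field names are illustrative, not the exact compressed-tensors schema):

target_scheme_map = {
    # An nn.Module class name target: quantize every Linear layer as W8A8.
    "Linear": {
        "weights": {"num_bits": 8, "type": "float", "dynamic": False},
        "input_activations": {"num_bits": 8, "type": "float", "dynamic": True},
    },
    # A regex target: leave the LM head unquantized.
    "re:.*lm_head": {"weights": None, "input_activations": None},
}

scheme = target_scheme_map["Linear"]
weight_quant = scheme.get("weights")            # drives the fp8 vs. int8 W8A8 branch
input_quant = scheme.get("input_activations")   # static vs. dynamic activation scales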
tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -0,0 +1,203 @@
+from typing import Callable, Optional, Union
+
+import jax
+import jax.numpy as jnp
+import torch
+import torch.nn.functional as F
+from jax.experimental.layout import Format, Layout
+from jax.sharding import Mesh, NamedSharding
+from jax.sharding import PartitionSpec as P
+from torch.nn.parameter import Parameter
+from torchax.interop import call_jax, torch_view
+from torchax.ops.mappings import t2j
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEConfig
+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import \
+    CompressedTensorsConfig
+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import \
+    CompressedTensorsW8A8Fp8MoEMethod
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (  # noqa
+    WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
+
+from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
+
+logger = init_logger(__name__)
+
+
+class VllmCompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsW8A8Fp8MoEMethod,
+                                            JaxCommonConfig):
+
+    def __init__(self, quant_config: "CompressedTensorsConfig",
+                 moe: FusedMoEConfig, mesh: Mesh):
+        super().__init__(quant_config, moe)
+        self.mesh = mesh
+        self.quant_config = quant_config
+
+        # disable GPU paths
+        self.use_marlin = False
+        self.rocm_aiter_moe_enabled = False  # is_rocm_aiter_moe_enabled()
+        self.is_fp8_w8a8_sm100 = False
+        self.use_cutlass = False
+        self.disable_expert_map = False
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        assert isinstance(layer, FusedMoE)
+
+        intermediate_size = layer.w13_weight.shape[1] // 2
+        w1_weight = layer.w13_weight[:, :intermediate_size]
+        w3_weight = layer.w13_weight[:, intermediate_size:]
+        w1_weight_scale = layer.w13_weight_scale[:, :intermediate_size]
+        w3_weight_scale = layer.w13_weight_scale[:, intermediate_size:]
+
+        w2_weight = t2j(layer.w2_weight, use_dlpack=False)
+        w2_weight_scale = t2j(layer.w2_weight_scale.to(torch.bfloat16),
+                              use_dlpack=False)
+        w1_weight = t2j(w1_weight, use_dlpack=False)
+        w1_weight_scale = t2j(w1_weight_scale.to(torch.bfloat16),
+                              use_dlpack=False)
+        w3_weight = t2j(w3_weight, use_dlpack=False)
+        w3_weight_scale = t2j(w3_weight_scale.to(torch.bfloat16),
+                              use_dlpack=False)
+
+        if layer.use_ep:
+            format = Format(Layout((0, 1, 2)),
+                            NamedSharding(self.mesh, P("model", None, None)))
+            w1_weight = jax.device_put(w1_weight, format)
+            w1_weight_scale = jax.device_put(w1_weight_scale, format)
+            w3_weight = jax.device_put(w3_weight, format)
+            w3_weight_scale = jax.device_put(w3_weight_scale, format)
+            w2_weight = jax.device_put(w2_weight, format)
+            w2_weight_scale = jax.device_put(w2_weight_scale, format)
+        else:
+            assert intermediate_size == w2_weight.shape[-1]
+            n_shards = self.mesh.shape["model"]
+            assert intermediate_size % n_shards == 0
+
+            # TODO: enable this if using fused weights
+            # output_sizes = [intermediate_size, intermediate_size]
+            # w13_weight = reorder_concatenated_tensor_for_sharding(
+            #     w13_weight, output_sizes, n_shards, dim=1
+            # )
+
+            w13_format = Format(
+                Layout((0, 1, 2)),
+                NamedSharding(self.mesh, P(None, "model", None)))
+            w1_weight = jax.device_put(w1_weight, w13_format)
+            w1_weight_scale = jax.device_put(w1_weight_scale, w13_format)
+            w3_weight = jax.device_put(w3_weight, w13_format)
+            w3_weight_scale = jax.device_put(w3_weight_scale, w13_format)
+            w2_weight = jax.device_put(
+                w2_weight,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P(None, None, "model"))),
+            )
+            w2_weight_scale = jax.device_put(
+                w2_weight_scale,
+                Format(Layout((0, 1, 2)), NamedSharding(self.mesh, P())),
+            )  # replicate
+
+        w1_weight = Parameter(torch_view(w1_weight), requires_grad=False)
+        w1_weight_scale = Parameter(torch_view(w1_weight_scale),
+                                    requires_grad=False)
+        w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
+        w2_weight_scale = Parameter(torch_view(w2_weight_scale),
+                                    requires_grad=False)
+        w3_weight = Parameter(torch_view(w3_weight), requires_grad=False)
+        w3_weight_scale = Parameter(torch_view(w3_weight_scale),
+                                    requires_grad=False)
+
+        # TODO: don't reuse variable
+        layer.w13_weight = w1_weight
+        layer.w13_weight_scale = w1_weight_scale
+        layer.w2_weight = w2_weight
+        layer.w2_weight_scale = w2_weight_scale
+        layer.w3_weight = w3_weight
+        layer.w3_weight_scale = w3_weight_scale
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        assert isinstance(layer, FusedMoE)
+        if activation != "silu":
+            raise NotImplementedError(
+                "Only silu is supported for activation function.")
+        if scoring_func != "softmax":
+            raise NotImplementedError(
+                "Only softmax is supported for scoring_func")
+
+        # import sys
+        # sys.stdin = open(0)
+        # breakpoint()
+
+        # TODO: Use MoE kernel when it supports fp8
+
+        seqlen = x.shape[0]
+
+        expert_weights = F.softmax(router_logits, dim=-1)
+        expert_weights, expert_indices = torch.topk(expert_weights,
+                                                    top_k,
+                                                    dim=-1)
+        if renormalize:
+            expert_weights /= expert_weights.sum(dim=-1, keepdim=True)
+
+        # cond ffn
+        # e = total num of experts = 160
+        # t = seqlen
+        # o = config.intermediate size
+        # i = config.dim
+        # torch.einsum("ti, eoi -> teo", x, layer.w13_weight) * self.w13_weight_scale
+        ux1 = call_jax(jax.lax.dot,
+                       x,
+                       layer.w13_weight,
+                       dimension_numbers=(((1, ), (2, )), ((), ())),
+                       preferred_element_type=jnp.bfloat16.dtype)
+        x1 = F.silu(ux1 * layer.w13_weight_scale.squeeze(2))
+
+        # x3 = torch.einsum("ti, eoi -> teo", x, layer.w3_weight) * self.w3_weight_scale
+        x3 = call_jax(jax.lax.dot,
+                      x,
+                      layer.w3_weight,
+                      dimension_numbers=(((1, ), (2, )), ((), ())),
+                      preferred_element_type=jnp.bfloat16.dtype
+                      ) * layer.w3_weight_scale.squeeze(2)
+
+        # expert_outs = torch.einsum("teo, eio -> tei", (x1 * x3), self.w2_weight) * self.w2_weight_scale
+        expert_outs = call_jax(
+            jax.lax.dot,
+            x1 * x3,
+            layer.w2_weight,
+            dimension_numbers=(((2, ), (2, )), ((1, ), (0, ))),
+            preferred_element_type=jnp.bfloat16.dtype).transpose(
+                0, 1) * layer.w2_weight_scale.squeeze(2)
+
+        seq_indexes = torch.arange(seqlen, device='jax').unsqueeze(1)
+        expert_outs = expert_outs[seq_indexes, expert_indices]
+
+        # out = torch.einsum("tai,ta -> ti", expert_outs, expert_weights)
+        out = call_jax(jax.lax.dot,
+                       expert_outs,
+                       expert_weights,
+                       dimension_numbers=(((1, ), (1, )), ((0, ), (0, ))),
+                       preferred_element_type=jnp.bfloat16.dtype)
+
+        return out
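Each call_jax contraction above is annotated with the torch.einsum it replaces, using the legend e = experts, t = tokens, o = intermediate size, i = hidden dim. A standalone JAX sketch (illustration only, not the package's torchax code path) showing that the first contraction, "ti,eoi->teo", corresponds to a dot_general that contracts x's dim 1 with w13's dim 2 and has no batch dims:

import jax
import jax.numpy as jnp

t, i, e, o = 4, 16, 8, 32                        # toy sizes
x = jnp.ones((t, i), jnp.bfloat16)
w13 = jnp.ones((e, o, i), jnp.bfloat16)

via_einsum = jnp.einsum("ti,eoi->teo", x, w13)
via_dot = jax.lax.dot_general(
    x, w13,
    dimension_numbers=(((1,), (2,)), ((), ())),  # contract i; no batch dims
    preferred_element_type=jnp.bfloat16)

assert via_einsum.shape == via_dot.shape == (t, e, o)
assert jnp.allclose(via_einsum, via_dot)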
tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py: File without changes