tpu-inference 0.11.1.dev202511150811__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_dp_scheduler.py +899 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/fused_moe_v1_test.py +105 -0
- tests/kernels/mla_v1_test.py +396 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/conftest.py +32 -0
- tests/lora/test_bgmv.py +43 -0
- tests/lora/test_layers.py +654 -0
- tests/lora/test_lora.py +133 -0
- tests/lora/utils.py +96 -0
- tests/test_base.py +201 -0
- tests/test_envs.py +182 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +236 -0
- tpu_inference/__init__.py +34 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/core/sched/__init__.py +0 -0
- tpu_inference/core/sched/dp_scheduler.py +523 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/jax_parallel_state.py +67 -0
- tpu_inference/distributed/tpu_connector.py +728 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +107 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +362 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
- tpu_inference/kernels/mla/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/kernel.py +1349 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_interface.py +390 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/common/sharding.py +582 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +255 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +280 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +96 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
- tpu_inference/layers/jax/transformer_block.py +107 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +507 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +39 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
- tpu_inference/layers/vllm/sharding.py +230 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +311 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +444 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/gpt_oss.py +492 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
- tpu_inference/models/jax/llama3.py +375 -0
- tpu_inference/models/jax/llama4.py +629 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
- tpu_inference/models/jax/utils/weight_utils.py +529 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_platform.py +269 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +780 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +132 -0
- tpu_inference/runner/kv_cache_manager.py +479 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +217 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +248 -0
- tpu_inference/runner/structured_decoding_manager.py +88 -0
- tpu_inference/runner/tpu_runner.py +1620 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +367 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +317 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/tpu_worker.py +321 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import jax
|
|
4
|
+
import jax.numpy as jnp
|
|
5
|
+
import torch
|
|
6
|
+
from compressed_tensors.quantization import (QuantizationArgs,
|
|
7
|
+
QuantizationStrategy)
|
|
8
|
+
from jax.sharding import NamedSharding, PartitionSpec
|
|
9
|
+
from torchax.interop import jax_view, torch_view
|
|
10
|
+
from torchax.ops.mappings import t2j
|
|
11
|
+
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
|
|
12
|
+
CompressedTensorsW8A8Fp8
|
|
13
|
+
from vllm.model_executor.layers.quantization.utils.w8a8_utils import \
|
|
14
|
+
per_tensor_dequantize
|
|
15
|
+
|
|
16
|
+
from tpu_inference.layers.vllm.linear_common import (
|
|
17
|
+
sharded_quantized_matmul, slice_sharded_tensor_for_concatenation,
|
|
18
|
+
torch_to_jax_param)
|
|
19
|
+
from tpu_inference.layers.vllm.quantization.common import JaxCommonLinearConfig
|
|
20
|
+
|
|
21
|
+
# Short alias for jax.sharding.PartitionSpec used when building NamedSharding.
P = PartitionSpec
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def requantize_with_max_scale(
|
|
25
|
+
weight: torch.Tensor, weight_scale: torch.Tensor,
|
|
26
|
+
logical_widths: list[int]) -> tuple[torch.Tensor, torch.Tensor]:
|
|
27
|
+
dtype = weight.dtype
|
|
28
|
+
dtype_info = torch.finfo(dtype)
|
|
29
|
+
maxval = float(dtype_info.max)
|
|
30
|
+
minval = float(dtype_info.min)
|
|
31
|
+
|
|
32
|
+
max_w_scale = weight_scale.max()
|
|
33
|
+
|
|
34
|
+
unfused_module_in_checkpoint = (weight_scale[-1]
|
|
35
|
+
> torch.finfo(torch.float8_e4m3fn).min)
|
|
36
|
+
|
|
37
|
+
# If unfused checkpoint, need requanize with the single scale.
|
|
38
|
+
if unfused_module_in_checkpoint:
|
|
39
|
+
start = 0
|
|
40
|
+
for idx, logical_width in enumerate(logical_widths):
|
|
41
|
+
# Skip any component with zero width.
|
|
42
|
+
if logical_width == 0:
|
|
43
|
+
continue
|
|
44
|
+
end = start + logical_width
|
|
45
|
+
weight_dq = per_tensor_dequantize(weight[start:end, :],
|
|
46
|
+
weight_scale[idx])
|
|
47
|
+
weight_q = weight_dq / max_w_scale
|
|
48
|
+
weight[start:end, :] = weight_q.clamp(minval, maxval).to(dtype)
|
|
49
|
+
start = end
|
|
50
|
+
|
|
51
|
+
return max_w_scale, weight
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class VllmCompressedTensorsW8A8Fp8(CompressedTensorsW8A8Fp8):
    """FP8 W8A8 compressed-tensors linear scheme executed with JAX on TPU.

    After loading, the layer's weight, scales and bias are replaced by
    JAX-backed parameters laid out according to ``jax_config``. At runtime
    the matmul runs either once over the fused weight
    (``jax_config.fuse_matmuls``) or once per output shard.
    """

    def __init__(
        self,
        weight_quant: QuantizationArgs,
        is_static_input_scheme: bool,
        jax_config: JaxCommonLinearConfig,
    ):
        super().__init__(weight_quant, is_static_input_scheme)

        self.jax_config = jax_config

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Replace ``layer``'s torch parameters with sharded JAX parameters.

        * ``input_scale`` (static scheme only): collapsed to its first element
          (all elements must be equal) and placed with an empty
          PartitionSpec.
        * ``weight_scale``: with the per-tensor strategy the weight is
          requantized to the max scale and that single scale is placed with
          an empty PartitionSpec; otherwise the trailing dim is squeezed and
          the scale is laid out with ``bias_sharding``.
        * ``weight`` / ``bias``: laid out with ``weight_sharding`` /
          ``bias_sharding``.
        """
        weight = layer.weight
        weight_scale = layer.weight_scale

        if self.is_static_input_scheme:
            # In static quant, all input_scales share the same value.
            assert layer.input_scale.min() == layer.input_scale.max()
            input_scale_first = layer.input_scale[0]

            input_scale = jax.device_put(
                t2j(input_scale_first, use_dlpack=False),
                NamedSharding(self.jax_config.mesh, P()))
            input_scale = torch.nn.Parameter(torch_view(input_scale),
                                             requires_grad=False)
            delattr(layer, "input_scale")
            layer.input_scale = input_scale

            # TODO(kyuyeunk): Investigate performance gain from merging scales.
            # By merging input and weight scales, we reduce the number of muls
            # required for dequantization from 2 (for each scales) to 1.
            # weight_scale *= input_scale_first

        if self.strategy == QuantizationStrategy.TENSOR:
            # All shards are requantized onto the max scale, which then
            # becomes the layer's only weight scale. This rewrites the rows
            # of layer.weight in place.
            weight_scale, weight = requantize_with_max_scale(
                weight, weight_scale, self.jax_config.output_sizes)
            weight_scale = jax.device_put(
                t2j(weight_scale, use_dlpack=False),
                NamedSharding(self.jax_config.mesh, P()))
            weight_scale = torch.nn.Parameter(torch_view(weight_scale),
                                              requires_grad=False)
        else:
            weight_scale = weight_scale.squeeze(-1)
            weight_scale = torch_to_jax_param(
                weight_scale,
                NamedSharding(self.jax_config.mesh,
                              self.jax_config.bias_sharding),
                self.jax_config.output_sizes, self.jax_config.n_shards,
                self.jax_config.fuse_matmuls)
        delattr(layer, "weight_scale")
        layer.weight_scale = weight_scale

        weight = torch_to_jax_param(
            layer.weight,
            NamedSharding(self.jax_config.mesh,
                          self.jax_config.weight_sharding),
            self.jax_config.output_sizes, self.jax_config.n_shards,
            self.jax_config.fuse_matmuls)
        delattr(layer, "weight")
        layer.weight = weight

        if layer.bias is not None:
            bias = torch_to_jax_param(
                layer.bias,
                NamedSharding(self.jax_config.mesh,
                              self.jax_config.bias_sharding),
                self.jax_config.output_sizes, self.jax_config.n_shards,
                self.jax_config.fuse_matmuls)
            delattr(layer, "bias")
            layer.bias = bias

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Compute the quantized linear output of ``layer`` for ``x``."""
        with jax.named_scope(layer._get_name()):
            if self.jax_config.fuse_matmuls:
                return self._apply_fused(layer, x, bias)
            else:
                return self._apply_split(layer, x, bias)

    def _static_quant_matmul(self, layer: torch.nn.Module, x_jax: jax.Array,
                             weight_jax: jax.Array,
                             weight_scale_jax: jax.Array) -> jax.Array:
        """Quantize ``x_jax`` with the static input scale, then matmul.

        The fp32 accumulator is rescaled by BOTH the weight scale and the
        input scale (the scales are not merged at load time), and the result
        is cast back to ``x_jax``'s dtype.
        """
        # TODO(kyuyeunk): Add kernel support for static quant
        input_scale = jax_view(layer.input_scale)
        dtype_info = jnp.finfo(weight_jax.dtype)
        maxval = float(dtype_info.max)
        minval = float(dtype_info.min)
        x_q = jnp.clip(x_jax / input_scale.astype(x_jax.dtype), minval,
                       maxval).astype(weight_jax.dtype)

        out = jax.lax.dot_general(
            x_q,
            weight_jax,
            (((1, ), (1, )), ((), ())),
            preferred_element_type=jnp.float32,
        )
        # TODO(kyuyeunk): Investigate performance gain from merging scales.
        # Since x was divided by input_scale above, it must be multiplied
        # back here alongside the weight scale.
        out *= weight_scale_jax * input_scale
        return out.astype(x_jax.dtype)

    def _apply_fused(self, layer: torch.nn.Module, x: torch.Tensor,
                     bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Run one matmul over the fused weight, then split per shard."""
        x_jax = jax_view(x)
        weight_jax = jax_view(layer.weight)
        weight_scale_jax = jax_view(layer.weight_scale)

        if self.is_static_input_scheme:
            outs = self._static_quant_matmul(layer, x_jax, weight_jax,
                                             weight_scale_jax)
        else:
            outs = sharded_quantized_matmul(x_jax, weight_jax,
                                            weight_scale_jax,
                                            self.jax_config.mesh,
                                            self.jax_config.weight_sharding)

        if bias is not None and not layer.skip_bias_add:
            outs += jax_view(bias)
        outs = slice_sharded_tensor_for_concatenation(
            outs, self.jax_config.output_sizes, self.jax_config.n_shards)
        return torch_view(jnp.concatenate(outs, axis=-1))

    def _apply_split(self, layer: torch.nn.Module, x: torch.Tensor,
                     bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Run one matmul per output shard and concatenate the results."""
        assert isinstance(layer.weight, torch.nn.ParameterList)

        x_jax = jax_view(x)
        weight_scales = layer.weight_scale
        if (not isinstance(weight_scales, torch.nn.ParameterList)
                and weight_scales.dim() == 0):
            # The per-tensor strategy stores one scalar scale shared by every
            # shard; a 0-d tensor cannot be iterated, so reuse it per shard.
            weight_scales = [weight_scales] * len(layer.weight)

        outs = []
        for i, (weight, weight_scale) in enumerate(
                zip(layer.weight, weight_scales)):
            weight_jax = jax_view(weight)
            weight_scale_jax = jax_view(weight_scale)

            if self.is_static_input_scheme:
                out = self._static_quant_matmul(layer, x_jax, weight_jax,
                                                weight_scale_jax)
            else:
                out = sharded_quantized_matmul(x_jax, weight_jax,
                                               weight_scale_jax,
                                               self.jax_config.mesh,
                                               self.jax_config.weight_sharding)

            if bias is not None and not layer.skip_bias_add:
                out += jax_view(bias[i])
            outs.append(out)
        return torch_view(jnp.concatenate(outs, axis=-1))
|
tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import jax
|
|
4
|
+
import jax.numpy as jnp
|
|
5
|
+
import torch
|
|
6
|
+
from compressed_tensors.quantization import QuantizationStrategy
|
|
7
|
+
from jax.sharding import NamedSharding, PartitionSpec
|
|
8
|
+
from torchax.interop import jax_view, torch_view
|
|
9
|
+
from vllm.logger import init_logger
|
|
10
|
+
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
|
|
11
|
+
CompressedTensorsW8A8Int8
|
|
12
|
+
from vllm.model_executor.layers.quantization.utils.w8a8_utils import \
|
|
13
|
+
convert_to_channelwise
|
|
14
|
+
|
|
15
|
+
from tpu_inference.layers.vllm.linear_common import (
|
|
16
|
+
sharded_quantized_matmul, slice_sharded_tensor_for_concatenation,
|
|
17
|
+
torch_to_jax_param)
|
|
18
|
+
from tpu_inference.layers.vllm.quantization.common import JaxCommonLinearConfig
|
|
19
|
+
|
|
20
|
+
# Short alias for jax.sharding.PartitionSpec.
P = PartitionSpec
|
|
21
|
+
# Module logger from vLLM's init_logger (provides warning_once).
logger = init_logger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class VllmCompressedTensorsW8A8Int8(CompressedTensorsW8A8Int8):
    """INT8 W8A8 compressed-tensors linear scheme executed with JAX on TPU.

    Static input quantization is not supported yet:
    ``process_weights_after_loading`` asserts that the layer has no
    ``input_scale``, ``input_zero_point`` or ``azp_adj``.
    """

    def __init__(self, strategy: str, is_static_input_scheme: bool,
                 input_symmetric: bool, jax_config: JaxCommonLinearConfig):
        super().__init__(strategy, is_static_input_scheme, input_symmetric)

        self.jax_config = jax_config
        # Must be a plain bool (no trailing comma): a 1-tuple is always
        # truthy, which would skip the convert_to_channelwise expansion in
        # process_weights_after_loading for fused per-tensor modules.
        self.is_channelwise = self.strategy == QuantizationStrategy.CHANNEL

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Replace ``layer``'s torch parameters with sharded JAX parameters.

        For fused modules (more than one logical width) whose strategy is not
        per-channel, the per-shard weight scales are first expanded with
        ``convert_to_channelwise``; the trailing dim is then squeezed and the
        scale is laid out with ``bias_sharding``.
        """
        weight = torch_to_jax_param(
            layer.weight,
            NamedSharding(self.jax_config.mesh,
                          self.jax_config.weight_sharding),
            self.jax_config.output_sizes,
            self.jax_config.n_shards,
            self.jax_config.fuse_matmuls,
        )
        delattr(layer, "weight")
        layer.weight = weight

        weight_scale = layer.weight_scale
        is_fused_module = len(layer.logical_widths) > 1
        if is_fused_module and not self.is_channelwise:
            weight_scale = convert_to_channelwise(weight_scale,
                                                  layer.logical_widths)
        weight_scale = weight_scale.squeeze(-1)

        weight_scale = torch_to_jax_param(
            weight_scale,
            NamedSharding(self.jax_config.mesh, self.jax_config.bias_sharding),
            self.jax_config.output_sizes,
            self.jax_config.n_shards,
            self.jax_config.fuse_matmuls,
        )
        delattr(layer, "weight_scale")
        layer.weight_scale = weight_scale

        if layer.bias is not None and not layer.skip_bias_add:
            if layer.return_bias:
                logger.warning_once("Bias might return incorrect value.")

            bias = torch_to_jax_param(
                layer.bias,
                NamedSharding(self.jax_config.mesh,
                              self.jax_config.bias_sharding),
                self.jax_config.output_sizes,
                self.jax_config.n_shards,
                self.jax_config.fuse_matmuls,
            )
            delattr(layer, "bias")
            layer.bias = bias

        # TODO(kyuyeunk): Support static range input quantization.
        assert getattr(layer, "input_scale", None) is None
        assert getattr(layer, "input_zero_point", None) is None
        assert getattr(layer, "azp_adj", None) is None

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Compute the quantized linear output of ``layer`` for ``x``."""
        with jax.named_scope(layer._get_name()):
            if self.jax_config.fuse_matmuls:
                out = self._apply_fused(layer, x, bias)
            else:
                out = self._apply_split(layer, x, bias)

        return out

    def _apply_fused(self, layer: torch.nn.Module, x: torch.Tensor,
                     bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Run one matmul over the fused weight, then split per shard."""
        x_jax = jax_view(x)
        weight_jax = jax_view(layer.weight)
        weight_scale_jax = jax_view(layer.weight_scale)

        outs = sharded_quantized_matmul(
            x_jax,
            weight_jax,
            weight_scale_jax,
            self.jax_config.mesh,
            self.jax_config.weight_sharding,
        )
        if bias is not None and not layer.skip_bias_add:
            outs += jax_view(bias)

        outs = slice_sharded_tensor_for_concatenation(
            outs, self.jax_config.output_sizes, self.jax_config.n_shards)
        out = jnp.concatenate(outs, axis=-1)
        return torch_view(out)

    def _apply_split(self, layer: torch.nn.Module, x: torch.Tensor,
                     bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Run one matmul per output shard and concatenate the results."""
        assert isinstance(layer.weight, torch.nn.ParameterList)

        x_jax = jax_view(x)
        outs = []
        for i, (weight, weight_scale) in enumerate(
                zip(layer.weight, layer.weight_scale)):
            weight_jax = jax_view(weight)
            weight_scale_jax = jax_view(weight_scale)

            out = sharded_quantized_matmul(
                x_jax,
                weight_jax,
                weight_scale_jax,
                self.jax_config.mesh,
                self.jax_config.weight_sharding,
            )
            if bias is not None and not layer.skip_bias_add:
                out += jax_view(bias[i])

            outs.append(out)
        out = jnp.concatenate(outs, axis=-1)
        return torch_view(out)
|
|
@@ -0,0 +1,266 @@
|
|
|
1
|
+
from typing import Callable, Optional, Union
|
|
2
|
+
|
|
3
|
+
import jax
|
|
4
|
+
import jax.numpy as jnp
|
|
5
|
+
import torch
|
|
6
|
+
from jax.experimental.layout import Format, Layout
|
|
7
|
+
from jax.sharding import Mesh, NamedSharding, PartitionSpec
|
|
8
|
+
from torch.nn.parameter import Parameter
|
|
9
|
+
from torchax.interop import jax_view, torch_view
|
|
10
|
+
from torchax.ops.mappings import t2j
|
|
11
|
+
from vllm.logger import init_logger
|
|
12
|
+
from vllm.model_executor.layers.fused_moe.config import (
|
|
13
|
+
FusedMoEConfig, FusedMoEQuantConfig, biased_moe_quant_config)
|
|
14
|
+
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
|
|
15
|
+
FusedMoEMethodBase)
|
|
16
|
+
from vllm.model_executor.layers.linear import LinearBase
|
|
17
|
+
from vllm.model_executor.layers.quantization import \
|
|
18
|
+
register_quantization_config
|
|
19
|
+
from vllm.model_executor.layers.quantization.base_config import \
|
|
20
|
+
QuantizeMethodBase
|
|
21
|
+
from vllm.model_executor.layers.quantization.mxfp4 import (Mxfp4Backend,
|
|
22
|
+
Mxfp4Config,
|
|
23
|
+
Mxfp4MoEMethod)
|
|
24
|
+
from vllm.model_executor.layers.quantization.utils.quant_utils import \
|
|
25
|
+
is_layer_skipped
|
|
26
|
+
|
|
27
|
+
from tpu_inference.layers.common.quant_methods import (MXFP4,
|
|
28
|
+
get_tpu_quant_method)
|
|
29
|
+
from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded
|
|
30
|
+
from tpu_inference.layers.vllm.linear_common import \
|
|
31
|
+
reorder_concatenated_tensor_for_sharding
|
|
32
|
+
from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
|
|
33
|
+
from tpu_inference.layers.vllm.quantization.unquantized import \
|
|
34
|
+
VllmUnquantizedLinearMethod
|
|
35
|
+
|
|
36
|
+
# Number of consecutive weight elements sharing one scale; passed as
# `block_size` to dequantize_block_weight for the MXFP4 MoE weights.
MXFP4_BLOCK_SIZE = 32
|
|
37
|
+
|
|
38
|
+
# Short alias for jax.sharding.PartitionSpec.
P = PartitionSpec
|
|
39
|
+
# Module logger from vLLM's init_logger (provides warning_once).
logger = init_logger(__name__)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# TODO(kyuyeunk): Move these functions into a common utility file.
|
|
43
|
+
def u8_unpack_e2m1(u8_packed_e2m1: jax.Array) -> jax.Array:
    """Unpack bytes that each hold two e2m1 (fp4) values.

    Args:
        u8_packed_e2m1: ``uint8`` array; every byte packs two fp4 values.

    Returns:
        ``float4_e2m1fn`` array with the same leading dims as the input and a
        last dim twice as long.
    """
    assert u8_packed_e2m1.dtype == jnp.uint8
    unpacked = jax.lax.bitcast_convert_type(u8_packed_e2m1, jnp.float4_e2m1fn)
    # The 8-bit -> 4-bit bitcast appends a trailing axis of size 2; fold that
    # axis into the input's last axis.
    merged_shape = unpacked.shape[:-2] + (-1, )
    return unpacked.reshape(merged_shape)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def e8m0_to_fp32(u8: jax.Array) -> jax.Array:
    """Decode raw e8m0 exponent bytes into float32 powers of two.

    Each byte ``b`` becomes ``2 ** (b + jnp.finfo(float8_e8m0fnu).minexp)``.

    Args:
        u8: Array of raw e8m0 bytes.

    Returns:
        float32 array of the same shape as ``u8``.
    """
    exponent_offset = jnp.finfo(jnp.float8_e8m0fnu).minexp
    mantissa = jnp.ones_like(u8, dtype=jnp.float32)
    return jnp.ldexp(mantissa, u8.astype(jnp.int32) + exponent_offset)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def dequantize_block_weight(weight: jax.Array,
                            scale: jax.Array,
                            block_size: int,
                            out_dtype: jnp.dtype = jnp.bfloat16) -> jax.Array:
    """Multiply each contiguous block of ``block_size`` values by its scale.

    The last axis of ``weight`` is split into blocks of ``block_size``
    elements. Each block is multiplied in float32 by its entry in ``scale``,
    and the result is reshaped back to ``weight.shape`` and cast to
    ``out_dtype``.

    Args:
        weight: Quantized values; the last dim must be a multiple of
            ``block_size``.
        scale: One scale per block, i.e. shape
            ``weight.shape[:-1] + (weight.shape[-1] // block_size,)``.
        block_size: Number of elements sharing one scale.
        out_dtype: dtype of the returned array.

    Returns:
        Dequantized array with ``weight``'s shape and dtype ``out_dtype``.
    """
    original_shape = weight.shape
    blocked = weight.reshape(original_shape[:-1] + (-1, block_size))
    per_block_scale = jnp.expand_dims(scale, -1)
    scaled = blocked.astype(jnp.float32) * per_block_scale
    return scaled.reshape(original_shape).astype(out_dtype)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@register_quantization_config(get_tpu_quant_method(MXFP4))
class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
    """MXFP4 quantization config registered for the TPU backend.

    Only fused-MoE layers get an MXFP4-aware method. Linear layers fall back
    to the unquantized linear method, and attention layers are left without a
    quantize method.
    """

    @classmethod
    def get_name(cls):
        """Return the registered quantization name."""
        return MXFP4

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Select the quantize method for ``layer``.

        Returns:
            ``VllmUnquantizedLinearMethod`` for linear layers (a warning is
            logged unless the layer is in ``ignored_layers``),
            ``VllmMxfp4MoEMethod`` for fused MoE layers, and ``None`` for
            attention (with a warning) and every other layer type.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            linear_config = self.get_linear_config(layer)
            explicitly_ignored = bool(self.ignored_layers) and is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            )
            if not explicitly_ignored:
                # TODO: Add support for MXFP4 Linear Method.
                # MXFP4 LinearMethod is available in AMD-Quark, refer to that
                # implementation if you are interested in enabling MXFP4 here.
                logger.warning_once(
                    "MXFP4 linear layer is not implemented - falling back to "
                    "UnquantizedLinearMethod.")
            return VllmUnquantizedLinearMethod(linear_config)

        if isinstance(layer, FusedMoE):
            return VllmMxfp4MoEMethod(layer.moe_config, self.mesh)

        if isinstance(layer, Attention):
            # TODO: Add support for MXFP4 Attention.
            logger.warning_once("MXFP4 attention layer is not implemented. "
                                "Skipping quantization for this layer.")

        return None
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
|
|
105
|
+
|
|
106
|
+
def __init__(self, moe: FusedMoEConfig, mesh: Mesh):
    """Initialize the MoE method for ``moe`` on the device mesh ``mesh``.

    Calls ``FusedMoEMethodBase.__init__`` directly rather than
    ``Mxfp4MoEMethod.__init__``, then pins the Triton MXFP4 backend.
    """
    FusedMoEMethodBase.__init__(self, moe)

    self.mesh = mesh
    # Piggyback on the triton implementation: it applies minimal
    # hardware-specific post-processing to the weights.
    self.mxfp4_backend = Mxfp4Backend.TRITON
|
|
113
|
+
|
|
114
|
+
def get_fused_moe_quant_config(
        self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None:
    """Return a bias-only fused-MoE quant config for ``layer``.

    The fp4 weights are dequantized to bf16 after loading, so only the w13
    and w2 biases need to be described.
    """
    # TODO(kyuyeunk): Add native support for MXFP4.
    bias_w13 = layer.w13_bias
    bias_w2 = layer.w2_bias
    return biased_moe_quant_config(bias_w13, bias_w2)
|
|
122
|
+
|
|
123
|
+
def process_weights_after_loading(self, layer: torch.nn.Module):
|
|
124
|
+
assert isinstance(layer, FusedMoE)
|
|
125
|
+
|
|
126
|
+
w13_weight = u8_unpack_e2m1(t2j(layer.w13_weight, use_dlpack=False))
|
|
127
|
+
w13_weight_scale = e8m0_to_fp32(
|
|
128
|
+
t2j(layer.w13_weight_scale, use_dlpack=False))
|
|
129
|
+
w13_bias = t2j(layer.w13_bias, use_dlpack=False)
|
|
130
|
+
|
|
131
|
+
w2_weight = u8_unpack_e2m1(t2j(layer.w2_weight, use_dlpack=False))
|
|
132
|
+
w2_weight_scale = e8m0_to_fp32(
|
|
133
|
+
t2j(layer.w2_weight_scale, use_dlpack=False))
|
|
134
|
+
w2_bias = t2j(layer.w2_bias, use_dlpack=False)
|
|
135
|
+
|
|
136
|
+
# We dequantize fp4 weights into bf16.
|
|
137
|
+
# TODO(kyuyeunk): Add native support for MXFP4.
|
|
138
|
+
w13_weight = dequantize_block_weight(w13_weight, w13_weight_scale,
|
|
139
|
+
MXFP4_BLOCK_SIZE, jnp.bfloat16)
|
|
140
|
+
w2_weight = dequantize_block_weight(w2_weight, w2_weight_scale,
|
|
141
|
+
MXFP4_BLOCK_SIZE, jnp.bfloat16)
|
|
142
|
+
|
|
143
|
+
# Because we have dequantized weights, scales are not used anymore.
|
|
144
|
+
delattr(layer, "w13_weight_scale")
|
|
145
|
+
delattr(layer, "w2_weight_scale")
|
|
146
|
+
|
|
147
|
+
if layer.activation == "swigluoai":
|
|
148
|
+
# When using swigluoai, vLLM splits gmm output in a interleaved way.
|
|
149
|
+
# However, interleaved split is not performant on TPU. Therefore,
|
|
150
|
+
# we preprocess the weight so that splitting gmm output by middle
|
|
151
|
+
# can still get the same result.
|
|
152
|
+
w1_weight = w13_weight[:, ::2, :]
|
|
153
|
+
w3_weight = w13_weight[:, 1::2, :]
|
|
154
|
+
w13_weight = jnp.concat([w1_weight, w3_weight], axis=1)
|
|
155
|
+
|
|
156
|
+
w1_bias = w13_bias[:, ::2]
|
|
157
|
+
w3_bias = w13_bias[:, 1::2]
|
|
158
|
+
w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
|
|
159
|
+
|
|
160
|
+
# TODO(kyuyeunk): Add weight processing logic for the new kernel.
|
|
161
|
+
if layer.use_ep:
|
|
162
|
+
w13_weight = jax.device_put(
|
|
163
|
+
w13_weight,
|
|
164
|
+
Format(Layout((0, 1, 2)),
|
|
165
|
+
NamedSharding(self.mesh, P("model", None, None))))
|
|
166
|
+
w2_weight = jax.device_put(
|
|
167
|
+
w2_weight,
|
|
168
|
+
Format(Layout((0, 1, 2)),
|
|
169
|
+
NamedSharding(self.mesh, P("model", None, None))))
|
|
170
|
+
|
|
171
|
+
w13_bias = jax.device_put(
|
|
172
|
+
w13_bias,
|
|
173
|
+
Format(Layout((0, 1)),
|
|
174
|
+
NamedSharding(self.mesh, P("model", None))))
|
|
175
|
+
w2_bias = jax.device_put(
|
|
176
|
+
w2_bias,
|
|
177
|
+
Format(Layout((0, 1)),
|
|
178
|
+
NamedSharding(self.mesh, P("model", None))))
|
|
179
|
+
|
|
180
|
+
else:
|
|
181
|
+
intermediate_size = w13_weight.shape[1] // 2
|
|
182
|
+
assert intermediate_size == w2_weight.shape[-1]
|
|
183
|
+
output_sizes = [intermediate_size, intermediate_size]
|
|
184
|
+
n_shards = self.mesh.shape["model"]
|
|
185
|
+
assert intermediate_size % n_shards == 0
|
|
186
|
+
w13_weight = reorder_concatenated_tensor_for_sharding(w13_weight,
|
|
187
|
+
output_sizes,
|
|
188
|
+
n_shards,
|
|
189
|
+
dim=1)
|
|
190
|
+
w13_weight = jax.device_put(
|
|
191
|
+
w13_weight,
|
|
192
|
+
Format(Layout((0, 1, 2)),
|
|
193
|
+
NamedSharding(self.mesh, P(None, "model", None))))
|
|
194
|
+
w2_weight = jax.device_put(
|
|
195
|
+
w2_weight,
|
|
196
|
+
Format(Layout((0, 1, 2)),
|
|
197
|
+
NamedSharding(self.mesh, P(None, None, "model"))))
|
|
198
|
+
|
|
199
|
+
w13_bias = reorder_concatenated_tensor_for_sharding(w13_bias,
|
|
200
|
+
output_sizes,
|
|
201
|
+
n_shards,
|
|
202
|
+
dim=1)
|
|
203
|
+
w13_bias = jax.device_put(
|
|
204
|
+
w13_bias,
|
|
205
|
+
Format(Layout((0, 1)),
|
|
206
|
+
NamedSharding(self.mesh, P(None, "model"))))
|
|
207
|
+
w2_bias = jax.device_put(
|
|
208
|
+
w2_bias,
|
|
209
|
+
Format(Layout((0, 1)), NamedSharding(self.mesh, P(None,
|
|
210
|
+
None))))
|
|
211
|
+
|
|
212
|
+
layer.w13_weight = Parameter(torch_view(w13_weight),
|
|
213
|
+
requires_grad=False)
|
|
214
|
+
layer.w13_bias = Parameter(torch_view(w13_bias), requires_grad=False)
|
|
215
|
+
|
|
216
|
+
layer.w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
|
|
217
|
+
layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False)
|
|
218
|
+
|
|
219
|
+
pass
|
|
220
|
+
|
|
221
|
+
def apply(
|
|
222
|
+
self,
|
|
223
|
+
layer: torch.nn.Module,
|
|
224
|
+
x: torch.Tensor,
|
|
225
|
+
router_logits: torch.Tensor,
|
|
226
|
+
top_k: int,
|
|
227
|
+
renormalize: bool,
|
|
228
|
+
use_grouped_topk: bool = False,
|
|
229
|
+
topk_group: Optional[int] = None,
|
|
230
|
+
num_expert_group: Optional[int] = None,
|
|
231
|
+
global_num_experts: int = -1,
|
|
232
|
+
expert_map: Optional[torch.Tensor] = None,
|
|
233
|
+
custom_routing_function: Optional[Callable] = None,
|
|
234
|
+
scoring_func: str = "softmax",
|
|
235
|
+
routed_scaling_factor: float = 1.0,
|
|
236
|
+
e_score_correction_bias: Optional[torch.Tensor] = None,
|
|
237
|
+
apply_router_weight_on_input: bool = False,
|
|
238
|
+
activation: str = "silu",
|
|
239
|
+
enable_eplb: bool = False,
|
|
240
|
+
expert_load_view: Optional[torch.Tensor] = None,
|
|
241
|
+
logical_to_physical_map: Optional[torch.Tensor] = None,
|
|
242
|
+
logical_replica_count: Optional[torch.Tensor] = None,
|
|
243
|
+
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
|
244
|
+
assert isinstance(layer, FusedMoE)
|
|
245
|
+
if scoring_func != "softmax":
|
|
246
|
+
raise NotImplementedError(
|
|
247
|
+
"Only softmax is supported for scoring_func")
|
|
248
|
+
|
|
249
|
+
# Use the original implementation
|
|
250
|
+
output = fused_moe_func_padded(
|
|
251
|
+
jax_view(x),
|
|
252
|
+
jax_view(layer.w13_weight),
|
|
253
|
+
jax_view(layer.w2_weight),
|
|
254
|
+
jax_view(layer.w13_bias) if self.moe.has_bias else None,
|
|
255
|
+
jax_view(layer.w2_bias) if self.moe.has_bias else None,
|
|
256
|
+
jax_view(router_logits),
|
|
257
|
+
topk=top_k,
|
|
258
|
+
global_num_experts=global_num_experts,
|
|
259
|
+
renormalize=renormalize,
|
|
260
|
+
reduce_results=layer.reduce_results,
|
|
261
|
+
mesh=self.mesh,
|
|
262
|
+
use_ep=layer.use_ep,
|
|
263
|
+
activation=activation,
|
|
264
|
+
)
|
|
265
|
+
|
|
266
|
+
return torch_view(output)
|