tpu-inference 0.11.1rc2__py3-none-any.whl → 0.11.1rc3__py3-none-any.whl
This diff compares publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +254 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/attention_interface.py +356 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/binary_search.py +295 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +172 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +95 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
- tpu_inference/layers/jax/sharding.py +406 -0
- tpu_inference/layers/jax/transformer_block.py +76 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +184 -0
- tpu_inference/layers/vllm/fused_moe.py +399 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +34 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
- tpu_inference/layers/vllm/sharding.py +151 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +433 -0
- {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/METADATA +6 -6
- {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/RECORD +50 -5
- {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/WHEEL +1 -1
- {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/top_level.txt +0 -0
tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
ADDED
@@ -0,0 +1,136 @@
+from typing import Optional
+
+import jax
+import jax.numpy as jnp
+import torch
+from compressed_tensors.quantization import QuantizationStrategy
+from jax.sharding import NamedSharding, PartitionSpec
+from torchax.interop import jax_view, torch_view
+from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
+    CompressedTensorsW8A8Int8
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import \
+    convert_to_channelwise
+
+from tpu_inference.layers.vllm.linear_common import (
+    sharded_quantized_matmul, slice_sharded_tensor_for_concatenation,
+    torch_to_jax_param)
+from tpu_inference.layers.vllm.quantization.common import JaxCommonLinearConfig
+
+P = PartitionSpec
+logger = init_logger(__name__)
+
+
+class VllmCompressedTensorsW8A8Int8(CompressedTensorsW8A8Int8):
+
+    def __init__(self, strategy: str, is_static_input_scheme: bool,
+                 input_symmetric: bool, jax_config: JaxCommonLinearConfig):
+        super().__init__(strategy, is_static_input_scheme, input_symmetric)
+
+        self.jax_config = jax_config
+        self.is_channelwise = (self.strategy == QuantizationStrategy.CHANNEL),
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        weight = torch_to_jax_param(
+            layer.weight,
+            NamedSharding(self.jax_config.mesh,
+                          self.jax_config.weight_sharding),
+            self.jax_config.output_sizes,
+            self.jax_config.n_shards,
+            self.jax_config.fuse_matmuls,
+        )
+        delattr(layer, "weight")
+        layer.weight = weight
+
+        weight_scale = layer.weight_scale
+        is_fused_module = len(layer.logical_widths) > 1
+        if is_fused_module and not self.is_channelwise:
+            weight_scale = convert_to_channelwise(weight_scale,
+                                                  layer.logical_widths)
+        weight_scale = weight_scale.squeeze(-1)
+
+        weight_scale = torch_to_jax_param(
+            weight_scale,
+            NamedSharding(self.jax_config.mesh, self.jax_config.bias_sharding),
+            self.jax_config.output_sizes,
+            self.jax_config.n_shards,
+            self.jax_config.fuse_matmuls,
+        )
+        delattr(layer, "weight_scale")
+        layer.weight_scale = weight_scale
+
+        if layer.bias is not None and not layer.skip_bias_add:
+            if layer.return_bias:
+                logger.warning_once("Bias might return incorrect value.")
+
+            bias = torch_to_jax_param(
+                layer.bias,
+                NamedSharding(self.jax_config.mesh,
+                              self.jax_config.bias_sharding),
+                self.jax_config.output_sizes,
+                self.jax_config.n_shards,
+                self.jax_config.fuse_matmuls,
+            )
+            delattr(layer, "bias")
+            layer.bias = bias
+
+        # TODO(kyuyeunk): Support static range input quantization.
+        assert getattr(layer, "input_scale", None) is None
+        assert getattr(layer, "input_zero_point", None) is None
+        assert getattr(layer, "azp_adj", None) is None
+
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]) -> torch.Tensor:
+        with jax.named_scope(layer._get_name()):
+            if self.jax_config.fuse_matmuls:
+                out = self._apply_fused(layer, x, bias)
+            else:
+                out = self._apply_split(layer, x, bias)
+
+            return out
+
+    def _apply_fused(self, layer: torch.nn.Module, x: torch.Tensor,
+                     bias: Optional[torch.Tensor]) -> torch.Tensor:
+        x_jax = jax_view(x)
+        weight_jax = jax_view(layer.weight)
+        weight_scale_jax = jax_view(layer.weight_scale)
+
+        outs = sharded_quantized_matmul(
+            x_jax,
+            weight_jax,
+            weight_scale_jax,
+            self.jax_config.mesh,
+            self.jax_config.weight_sharding,
+        )
+        if bias is not None and not layer.skip_bias_add:
+            outs += jax_view(bias)
+
+        outs = slice_sharded_tensor_for_concatenation(
+            outs, self.jax_config.output_sizes, self.jax_config.n_shards)
+        out = jnp.concatenate(outs, axis=-1)
+        return torch_view(out)
+
+    def _apply_split(self, layer: torch.nn.Module, x: torch.Tensor,
+                     bias: Optional[torch.Tensor]) -> torch.Tensor:
+        assert isinstance(layer.weight, torch.nn.ParameterList)
+
+        x_jax = jax_view(x)
+        outs = []
+        for i, (weight, weight_scale) in enumerate(
+                zip(layer.weight, layer.weight_scale)):
+            weight_jax = jax_view(weight)
+            weight_scale_jax = jax_view(weight_scale)
+
+            out = sharded_quantized_matmul(
+                x_jax,
+                weight_jax,
+                weight_scale_jax,
+                self.jax_config.mesh,
+                self.jax_config.weight_sharding,
+            )
+            if bias is not None and not layer.skip_bias_add:
+                out += jax_view(bias[i])
+
+            outs.append(out)
+        out = jnp.concatenate(outs, axis=-1)
+        return torch_view(out)
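The scheme above hands the actual quantized matmul to sharded_quantized_matmul from linear_common.py, which is not shown in this diff. For orientation only, here is a minimal JAX sketch of the channelwise W8A8 pattern such a helper is typically built around; the function name, shapes, and scale layout are assumptions, not the package's API.

import jax.numpy as jnp

def w8a8_channelwise_matmul(x, w_int8, w_scale):
    # x: [m, n] activations, w_int8: [p, n] int8 weights,
    # w_scale: [p] per-output-channel scales (the real kernel may also
    # quantize the activations; this sketch only dequantizes the weights).
    acc = jnp.einsum("mn,pn->mp", x, w_int8.astype(x.dtype))
    return acc * w_scale[None, :]

x = jnp.ones((4, 8), dtype=jnp.bfloat16)
w = jnp.ones((16, 8), dtype=jnp.int8)
scale = jnp.full((16,), 0.02, dtype=jnp.bfloat16)
print(w8a8_channelwise_matmul(x, w, scale).shape)  # (4, 16)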
tpu_inference/layers/vllm/quantization/unquantized.py
ADDED
@@ -0,0 +1,263 @@
+import functools
+from typing import Any, Callable, Optional, Union
+
+import jax
+import jax.numpy as jnp
+import torch
+from jax.experimental.layout import Format, Layout
+from jax.sharding import Mesh, NamedSharding, PartitionSpec
+from torch.nn.parameter import Parameter
+from torchax.interop import jax_view, torch_view
+from torchax.ops.mappings import t2j
+from vllm.attention.layer import Attention
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.layer import (
+    FusedMoE, FusedMoEConfig, UnquantizedFusedMoEMethod)
+from vllm.model_executor.layers.fused_moe.modular_kernel import (
+    FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize)
+from vllm.model_executor.layers.linear import (LinearBase,
+                                               UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import \
+    register_quantization_config
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig, QuantizeMethodBase)
+
+from tpu_inference.layers.vllm.fused_moe import jax_fused_moe_func_padded
+from tpu_inference.layers.vllm.linear_common import (
+    reorder_concatenated_tensor_for_sharding,
+    slice_sharded_tensor_for_concatenation, torch_to_jax_param)
+from tpu_inference.layers.vllm.quantization.common import (
+    JaxCommonConfig, JaxCommonLinearConfig)
+
+P = PartitionSpec
+logger = init_logger(__name__)
+
+
+@register_quantization_config("jax-unquantized")
+class VllmUnquantizedConfig(QuantizationConfig, JaxCommonConfig):
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "jax-unquantized"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
+        return [torch.float32, torch.float16, torch.bfloat16]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 0  # Always supported
+
+    @classmethod
+    def get_config_filenames(cls) -> list[str]:
+        return []  # No extra configs required.
+
+    @classmethod
+    def from_config(cls, _: dict[str, Any]) -> "VllmUnquantizedConfig":
+        return cls()
+
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional[QuantizeMethodBase]:
+        if isinstance(layer, LinearBase):
+            linear_config = self.get_linear_config(layer)
+            return VllmUnquantizedLinearMethod(linear_config)
+        if isinstance(layer, FusedMoE):
+            moe_config = self.get_moe_config(layer)
+            return VllmUnquantizedFusedMoEMethod(moe_config, self.mesh)
+        if isinstance(layer, Attention):
+            return None
+        return None
+
+
+class VllmUnquantizedLinearMethod(UnquantizedLinearMethod):
+
+    def __init__(self, jax_config: JaxCommonLinearConfig):
+        self.jax_config = jax_config
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        weight = torch_to_jax_param(
+            layer.weight,
+            NamedSharding(self.jax_config.mesh,
+                          self.jax_config.weight_sharding),
+            self.jax_config.output_sizes,
+            self.jax_config.n_shards,
+            self.jax_config.fuse_matmuls,
+        )
+        delattr(layer, "weight")
+        layer.weight = weight
+
+        if layer.bias is not None and not layer.skip_bias_add:
+            if layer.return_bias:
+                logger.warning_once("Bias might return incorrect value.")
+
+            bias = torch_to_jax_param(
+                layer.bias,
+                NamedSharding(self.jax_config.mesh,
+                              self.jax_config.bias_sharding),
+                self.jax_config.output_sizes,
+                self.jax_config.n_shards,
+                self.jax_config.fuse_matmuls,
+            )
+            delattr(layer, "bias")
+            layer.bias = bias
+
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        with jax.named_scope(layer._get_name()):
+            if in_sharding := self.jax_config.get_input_sharding(x):
+                x.shard_(NamedSharding(self.jax_config.mesh, in_sharding))
+
+            if self.jax_config.fuse_matmuls:
+                out = self._apply_fused(layer, x, bias)
+            else:
+                out = self._apply_split(layer, x, bias)
+
+            if out_sharding := self.jax_config.get_output_sharding(out):
+                out.shard_(NamedSharding(self.jax_config.mesh, out_sharding))
+
+            return out
+
+    def _apply_fused(self,
+                     layer: torch.nn.Module,
+                     x: torch.Tensor,
+                     bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        x_jax = jax_view(x)
+        weight_jax = jax_view(layer.weight)
+
+        outs = jnp.einsum("mn,pn->mp", x_jax, weight_jax)
+        if bias is not None and not layer.skip_bias_add:
+            outs += bias.jax()
+
+        outs = slice_sharded_tensor_for_concatenation(
+            outs, self.jax_config.output_sizes, self.jax_config.n_shards)
+        out = jnp.concatenate(outs, axis=-1)
+        return torch_view(out)
+
+    def _apply_split(self,
+                     layer: torch.nn.Module,
+                     x: torch.Tensor,
+                     bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        assert isinstance(layer.weight, torch.nn.ParameterList)
+
+        x_jax = x.jax()
+        outs = []
+        for i, weight in enumerate(layer.weight):
+            weight_jax = jax_view(weight)
+
+            out = jnp.einsum("mn,pn->mp", x_jax, weight_jax)
+            if bias is not None and not layer.skip_bias_add:
+                out += jax_view(bias[i])
+
+            outs.append(out)
+        out = jnp.concatenate(outs, axis=-1)
+        return torch_view(out)
+
+
+class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
+
+    def __init__(self, moe: FusedMoEConfig, mesh: Mesh):
+        super().__init__(moe)
+        self.mesh = mesh
+
+    def select_gemm_impl(
+        self,
+        prepare_finalize: FusedMoEPrepareAndFinalize,
+        moe: FusedMoEConfig,
+        layer: torch.nn.Module,
+    ) -> FusedMoEPermuteExpertsUnpermute:
+        raise NotImplementedError(
+            "Selecting gemm implementation is currently not supported.")

+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        assert isinstance(layer, FusedMoE)
+
+        w2_weight = t2j(layer.w2_weight, use_dlpack=False)
+        w13_weight = t2j(layer.w13_weight, use_dlpack=False)
+
+        if layer.use_ep:
+            w13_weight = jax.device_put(
+                w13_weight,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P("model", None, None))))
+            w2_weight = jax.device_put(
+                w2_weight,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P("model", None, None))))
+        else:
+            intermediate_size = w13_weight.shape[1] // 2
+            assert intermediate_size == w2_weight.shape[-1]
+            output_sizes = [intermediate_size, intermediate_size]
+            n_shards = self.mesh.shape["model"]
+            assert intermediate_size % n_shards == 0
+            w13_weight = reorder_concatenated_tensor_for_sharding(w13_weight,
+                                                                  output_sizes,
+                                                                  n_shards,
+                                                                  dim=1)
+            w13_weight = jax.device_put(
+                w13_weight,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P(None, "model", None))))
+            w2_weight = jax.device_put(
+                w2_weight,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P(None, None, "model"))))
+        w13_weight = Parameter(torch_view(w13_weight), requires_grad=False)
+        w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
+
+        layer.w13_weight = w13_weight
+        layer.w2_weight = w2_weight
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        assert isinstance(layer, FusedMoE)
+        if activation != "silu":
+            raise NotImplementedError(
+                "Only silu is supported for activation function.")
+        if scoring_func != "softmax":
+            raise NotImplementedError(
+                "Only softmax is supported for scoring_func")
+
+        _fused_moe_func = functools.partial(
+            jax.jit(jax_fused_moe_func_padded,
+                    static_argnames=[
+                        "topk", "global_num_experts", "renormalize",
+                        "reduce_results", "mesh", "use_ep"
+                    ]),
+            topk=top_k,
+            global_num_experts=global_num_experts,
+            renormalize=renormalize,
+            reduce_results=layer.reduce_results,
+            mesh=self.mesh,
+            use_ep=layer.use_ep)
+
+        output = _fused_moe_func(
+            jax_view(x),
+            jax_view(layer.w13_weight),
+            jax_view(layer.w2_weight),
+            jax_view(router_logits),
+        )
+
+        return torch_view(output)
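In VllmUnquantizedLinearMethod._apply_fused the heavy lifting is a single jnp.einsum("mn,pn->mp", ...) over a weight whose output dimension is partitioned across the mesh by the sharding carried in jax_config (defined in common.py, outside this diff). The standalone JAX sketch below shows that column-parallel layout; the mesh construction and the P("model", None) spec are illustrative assumptions, not values taken from the package.

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D "model" mesh over whatever devices are available (one on CPU).
mesh = Mesh(jax.devices(), axis_names=("model",))

m, n, p = 8, 16, 32                        # tokens, in_features, out_features
x = jnp.ones((m, n), dtype=jnp.bfloat16)   # activations, replicated
w = jnp.ones((p, n), dtype=jnp.bfloat16)   # weight rows are output channels

# Partition the output dimension of the weight over "model"; keep x replicated.
w = jax.device_put(w, NamedSharding(mesh, P("model", None)))
x = jax.device_put(x, NamedSharding(mesh, P()))

out = jnp.einsum("mn,pn->mp", x, w)        # each device computes its slice of p
print(out.shape, out.sharding)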
tpu_inference/layers/vllm/sharding.py
ADDED
@@ -0,0 +1,151 @@
+import jax
+import torch
+import torchax
+from jax.sharding import Mesh, NamedSharding, PartitionSpec
+from torch.nn import Parameter
+from torch.utils import _pytree as pytree
+from torchax.interop import jax_view, torch_view
+from torchax.ops.mappings import t2j
+from vllm.lora.layers import (MergedColumnParallelLinearWithLoRA,
+                              MergedQKVParallelLinearWithLoRA,
+                              RowParallelLinearWithLoRA)
+from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead, VocabParallelEmbedding)
+
+from tpu_inference.logger import init_logger
+
+P = PartitionSpec
+
+logger = init_logger(__name__)
+
+
+def shard_model_to_tpu(model: torch.nn.Module,
+                       mesh: Mesh) -> dict[str, torchax.torch.Tensor]:
+    """
+    Shard the model weights and move them to TPU.
+
+    At the same time, also turn the weight tensors into torchax tensors so that
+    jax code can interop with it and the overall program can be traced and
+    compiled in XLA.
+
+    Args:
+        model: A PyTorch model whose weights are on CPU main memory.
+        mesh: JAX mesh object for sharding.
+    Returns:
+        Dictionary of parameters and buffers that will be used as arguments of
+        torch.func.functional_call
+    """

+    with jax.default_device(jax.devices("cpu")[0]):
+        _shard_module_to_tpu(model, mesh)
+
+        params, buffers = _extract_all_params_buffers(model)
+
+        # For other weight tensors, repliate them on all the TPU chips.
+        params, buffers = pytree.tree_map_only(
+            _tensor_is_in_cpu,
+            lambda tensor: _shard_tensor_to_tpu_replicated(tensor, mesh),
+            (params, buffers))
+
+    return {**params, **buffers}
+
+
+def _extract_all_params_buffers(model: torch.nn.Module):
+    return dict(model.named_parameters()), dict(model.named_buffers())
+
+
+def _tensor_is_in_cpu(tensor: torch.tensor) -> bool:
+    # Check if a tensor haven't been converted to torchax tensor.
+    if not isinstance(tensor, torchax.tensor.Tensor):
+        return True
+    # Check if torchax tensor is still in CPU.
+    return tensor.jax_device == jax.devices('cpu')[0]
+
+
+def _convert_to_torchax_and_shard(tensor: torch.Tensor,
+                                  sharding: NamedSharding) -> torch.Tensor:
+    if isinstance(tensor, torchax.tensor.Tensor):
+        tensor = jax_view(tensor)
+    else:
+        tensor = t2j(tensor)
+    return torch_view(jax.device_put(tensor, sharding))
+
+
+def _shard_tensor_to_tpu_replicated(tensor: torch.Tensor,
+                                    mesh: Mesh) -> torchax.tensor.Tensor:
+    return _convert_to_torchax_and_shard(tensor, NamedSharding(mesh, P()))
+
+
+def _shard_vocab_parallel_embedding(layer: VocabParallelEmbedding,
+                                    mesh: Mesh) -> None:
+    weight = _convert_to_torchax_and_shard(
+        layer.weight, NamedSharding(mesh, P('model', None)))
+    layer.weight = Parameter(weight, requires_grad=False)
+
+
+def _shard_lm_head(layer: ParallelLMHead, mesh: Mesh):
+    # TODO(qihqi): currently this is not handling case of tie_word_weights=True.
+    # if that config is set, then we should not create new weights but reuse the
+    # weight from VocabParallelEmbedding
+    weight = _convert_to_torchax_and_shard(
+        layer.weight, NamedSharding(mesh, P('model', None)))
+    layer.weight = Parameter(weight, requires_grad=False)
+    if layer.bias is not None:
+        bias = _convert_to_torchax_and_shard(layer.bias,
+                                             NamedSharding(mesh, P('model')))
+        layer.bias = Parameter(bias, requires_grad=False)
+
+
+def _shard_base_linear_lora(layer: BaseLinearLayerWithLoRA,
+                            mesh: Mesh) -> None:
+    # NOTE: lora_a_stacked[i] has shape [max_loras, 1, num_out, num_in]
+    sharded_lora_a_tpu = torch.nn.ParameterList()
+    sharded_lora_b_tpu = torch.nn.ParameterList()
+
+    for i in range(layer.n_slices):
+        sharded_lora_a_tpu.append(
+            _shard_tensor_to_tpu_replicated(layer.lora_a_stacked[i], mesh))
+        sharded_lora_b_tpu.append(
+            _shard_tensor_to_tpu_replicated(layer.lora_b_stacked[i], mesh))
+
+    layer.lora_a_stacked = sharded_lora_a_tpu
+    layer.lora_b_stacked = sharded_lora_b_tpu
+
+
+# TODO: Add custom sharding logic for following lora layers
+def _shard_column_parallel_linear_lora(
+        layer: MergedColumnParallelLinearWithLoRA, mesh: Mesh) -> None:
+    _shard_base_linear_lora(layer, mesh)
+
+
+def _shard_qkv_parallel_linear_lora(layer: MergedQKVParallelLinearWithLoRA,
+                                    mesh: Mesh) -> None:
+    _shard_base_linear_lora(layer, mesh)
+
+
+def _shard_row_parallel_linear_lora(layer: RowParallelLinearWithLoRA,
+                                    mesh: Mesh) -> None:
+    _shard_base_linear_lora(layer, mesh)
+
+
+# NOTE: Ordering is important as it calls first matched type of a given module
+MODULE_TYPE_TO_SHARDING_FUNC = [
+    # Shard embedding layers
+    (ParallelLMHead, _shard_lm_head),
+    (VocabParallelEmbedding, _shard_vocab_parallel_embedding),
+    # Shard LoRA layers
+    (MergedColumnParallelLinearWithLoRA, _shard_column_parallel_linear_lora),
+    (MergedQKVParallelLinearWithLoRA, _shard_qkv_parallel_linear_lora),
+    (RowParallelLinearWithLoRA, _shard_row_parallel_linear_lora),
+    (BaseLinearLayerWithLoRA, _shard_base_linear_lora),
+]
+
+
+def _shard_module_to_tpu(model: torch.nn.Module, mesh: Mesh) -> None:
+    for path, module in model.named_modules():
+        for module_type, sharding_func in MODULE_TYPE_TO_SHARDING_FUNC:
+            if isinstance(module, module_type):
+                logger.debug("shard %s with %s", path, sharding_func)
+                sharding_func(module, mesh)
+                break
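_shard_module_to_tpu walks model.named_modules() and applies the first matching entry of MODULE_TYPE_TO_SHARDING_FUNC, which is why the table is ordered from most to least specific (ParallelLMHead before its base VocabParallelEmbedding, the merged LoRA layers before BaseLinearLayerWithLoRA). Below is a minimal, TPU-free illustration of that ordered isinstance dispatch using plain torch.nn modules; the class and handler names are hypothetical.

import torch

class Base(torch.nn.Linear):
    pass

class Special(Base):
    pass

def handle_special(module):  # stands in for the more specific sharding function
    return "special"

def handle_base(module):     # stands in for the fallback sharding function
    return "base"

# Ordering matters: the first isinstance match wins, so the subclass handler
# must come before the base-class handler, mirroring the list above.
HANDLERS = [(Special, handle_special), (Base, handle_base)]

model = torch.nn.Sequential(Special(4, 4), Base(4, 4))
for path, module in model.named_modules():
    for module_type, handler in HANDLERS:
        if isinstance(module, module_type):
            print(path, "->", handler(module))
            break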