tpu-inference 0.11.1rc1__py3-none-any.whl → 0.11.1rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +254 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/attention_interface.py +356 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/binary_search.py +295 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +172 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +95 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
- tpu_inference/layers/jax/sharding.py +406 -0
- tpu_inference/layers/jax/transformer_block.py +76 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +184 -0
- tpu_inference/layers/vllm/fused_moe.py +399 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +34 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
- tpu_inference/layers/vllm/sharding.py +151 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +433 -0
- {tpu_inference-0.11.1rc1.dist-info → tpu_inference-0.11.1rc3.dist-info}/METADATA +6 -6
- {tpu_inference-0.11.1rc1.dist-info → tpu_inference-0.11.1rc3.dist-info}/RECORD +50 -5
- {tpu_inference-0.11.1rc1.dist-info → tpu_inference-0.11.1rc3.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1rc1.dist-info → tpu_inference-0.11.1rc3.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1rc1.dist-info → tpu_inference-0.11.1rc3.dist-info}/top_level.txt +0 -0
tpu_inference/layers/vllm/quantization/awq.py
ADDED

@@ -0,0 +1,207 @@
+from typing import Optional, Union
+
+import jax
+import jax.numpy as jnp
+import torch
+from jax.sharding import NamedSharding, PartitionSpec
+from torchax.interop import jax_view, torch_view
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization import \
+    register_quantization_config
+from vllm.model_executor.layers.quantization.awq import (AWQConfig,
+                                                         AWQLinearMethod,
+                                                         is_layer_skipped_awq)
+from vllm.model_executor.layers.quantization.base_config import \
+    QuantizeMethodBase
+from vllm.model_executor.layers.quantization.utils.quant_utils import \
+    unpack_quantized_values_into_int32
+from vllm.scalar_type import scalar_types
+
+from tpu_inference.layers.vllm.linear_common import (
+    slice_sharded_tensor_for_concatenation, torch_to_jax_param)
+from tpu_inference.layers.vllm.quantization.common import (
+    JaxCommonConfig, JaxCommonLinearConfig)
+from tpu_inference.layers.vllm.quantization.unquantized import \
+    VllmUnquantizedLinearMethod
+
+P = PartitionSpec
+logger = init_logger(__name__)
+
+
+@register_quantization_config("jax-awq")
+class VllmAWQConfig(AWQConfig, JaxCommonConfig):
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "jax-awq"
+
+    def get_supported_act_dtypes(self) -> list[torch.dtype]:
+        # NOTE: AWQ checkpoint was quantized with float16. But on TPUs, using
+        # bfloat16 is significantly preferred over float16. This might lead to
+        # some numeric output change.
+        return [torch.bfloat16]
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional[Union["LinearMethodBase", "QuantizeMethodBase"]]:
+        if isinstance(layer, LinearBase):
+            linear_config = self.get_linear_config(layer)
+            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+                return VllmUnquantizedLinearMethod(linear_config)
+            return VllmAWQLinearMethod(self, linear_config)
+        elif isinstance(layer, FusedMoE):
+            raise NotImplementedError(
+                "AWQ FusedMoE is currently not supported in torchax-jax")
+        return None
+
+
+class VllmAWQLinearMethod(AWQLinearMethod):
+
+    def __init__(self, quant_config: VllmAWQConfig,
+                 jax_config: JaxCommonLinearConfig):
+        super().__init__(quant_config)
+        self.jax_config = jax_config
+
+        out_sharding, in_sharding = self.jax_config.weight_sharding[:]
+        self.jax_config.weight_sharding = P(in_sharding, None, out_sharding)
+        self.jax_config.scale_sharding = P(in_sharding, out_sharding)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        qweight = layer.qweight
+        qweight = unpack_awq_weight(qweight, qweight.packed_dim)
+
+        group_size = self.quant_config.group_size
+        # Reshape so that each qweight[i] was quantized with the same scales[i].
+        qweight = qweight.reshape((-1, group_size, layer.output_size))
+        qweight = torch_to_jax_param(qweight,
+                                     NamedSharding(
+                                         self.jax_config.mesh,
+                                         self.jax_config.weight_sharding),
+                                     self.jax_config.output_sizes,
+                                     self.jax_config.n_shards,
+                                     self.jax_config.fuse_matmuls,
+                                     dim=2,
+                                     jax_dtype=jnp.uint4)
+        delattr(layer, "qweight")
+        layer.qweight = qweight
+
+        qzeros = layer.qzeros
+        qzeros = unpack_awq_weight(qzeros, qzeros.packed_dim)
+        qzeros = torch_to_jax_param(qzeros,
+                                    NamedSharding(
+                                        self.jax_config.mesh,
+                                        self.jax_config.scale_sharding),
+                                    self.jax_config.output_sizes,
+                                    self.jax_config.n_shards,
+                                    self.jax_config.fuse_matmuls,
+                                    dim=1,
+                                    jax_dtype=jnp.uint4)
+        delattr(layer, "qzeros")
+        layer.qzeros = qzeros
+
+        scales = torch_to_jax_param(layer.scales,
+                                    NamedSharding(
+                                        self.jax_config.mesh,
+                                        self.jax_config.scale_sharding),
+                                    self.jax_config.output_sizes,
+                                    self.jax_config.n_shards,
+                                    self.jax_config.fuse_matmuls,
+                                    dim=1)
+        delattr(layer, "scales")
+        layer.scales = scales
+
+        if layer.bias is not None and not layer.skip_bias_add:
+            if layer.return_bias:
+                logger.warning_once("Bias might return incorrect value.")
+
+            bias = torch_to_jax_param(
+                layer.bias,
+                NamedSharding(self.jax_config.mesh,
+                              self.jax_config.bias_sharding),
+                self.jax_config.output_sizes,
+                self.jax_config.n_shards,
+                self.jax_config.fuse_matmuls,
+            )
+            delattr(layer, "bias")
+            layer.bias = bias
+
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+        with jax.named_scope(layer._get_name()):
+            if self.jax_config.fuse_matmuls:
+                out = self._apply_fused(layer, x, bias)
+            else:
+                out = self._apply_split(layer, x, bias)
+
+        return out
+
+    def _apply_fused(self,
+                     layer: torch.nn.Module,
+                     x: torch.Tensor,
+                     bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        x_jax = jax_view(x)
+
+        qweight = jax_view(layer.qweight)
+        qzeros = jnp.expand_dims(jax_view(layer.qzeros), 1)
+        scales = jnp.expand_dims(jax_view(layer.scales), 1)
+
+        qweight = qweight.astype(jnp.int8)
+        qzeros = qzeros.astype(jnp.int8)
+
+        weight = (qweight - qzeros) * scales
+        weight = weight.reshape((-1, weight.shape[-1]))
+        outs = jnp.einsum("bd,df->bf", x_jax, weight)
+
+        if bias is not None and not layer.skip_bias_add:
+            outs += bias.jax()
+
+        outs = slice_sharded_tensor_for_concatenation(
+            outs, self.jax_config.output_sizes, self.jax_config.n_shards)
+        out = jnp.concatenate(outs, axis=-1)
+        return torch_view(out)
+
+    def _apply_split(self,
+                     layer: torch.nn.Module,
+                     x: torch.Tensor,
+                     bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        assert isinstance(layer.qweight, torch.nn.ParameterList)
+
+        x_jax = jax_view(x)
+        params = zip(layer.qweight, layer.qzeros, layer.scales)
+        outs = []
+        for i, (qweight, qzeros, scales) in enumerate(params):
+            qweight = jax_view(qweight)
+            scales = jnp.expand_dims(jax_view(scales), 1)
+            qzeros = jnp.expand_dims(jax_view(qzeros), 1)
+
+            qweight = qweight.astype(jnp.int8)
+            qzeros = qzeros.astype(jnp.int8)
+
+            weight = (qweight - qzeros) * scales
+            weight = weight.reshape((-1, weight.shape[-1]))
+            out = jnp.einsum("bd,df->bf", x_jax, weight)
+
+            if bias is not None and not layer.skip_bias_add:
+                out += jax_view(bias[i])
+
+            outs.append(out)
+        out = jnp.concatenate(outs, axis=-1)
+        return torch_view(out)
+
+
+def unpack_awq_weight(weight: torch.Tensor, packed_dim: int):
+    weight = unpack_quantized_values_into_int32(weight, scalar_types.uint4,
+                                                packed_dim)
+
+    # AWQ packs 8 uint4 into 32-bits in this order: (0, 2, 4, 6, 1, 3, 5, 7).
+    # Following list maps the order used by AWQ into an ascending order.
+    reverse_awq_order = (0, 4, 1, 5, 2, 6, 3, 7)
+
+    orig_shape = weight.shape
+    weight = weight.reshape(orig_shape[:-1] + (-1, 8))
+    return weight[..., reverse_awq_order].reshape(orig_shape)
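For readers skimming the diff, the two ideas in awq.py — the AWQ nibble reordering undone by unpack_awq_weight, and the group-wise dequantization used in _apply_fused/_apply_split — can be sketched in a few lines of plain JAX. This is a toy illustration with invented shapes and values, not part of the package:

# Illustrative only: toy demo of the reverse_awq_order reordering and the
# group-wise (qweight - qzeros) * scales dequantization from awq.py.
# Shapes and values are invented; no vLLM or torchax dependencies.
import jax.numpy as jnp

# AWQ stores 8 uint4 nibbles per int32 in the order (0, 2, 4, 6, 1, 3, 5, 7);
# indexing with this tuple restores ascending nibble order, as in unpack_awq_weight.
reverse_awq_order = (0, 4, 1, 5, 2, 6, 3, 7)
nibbles = jnp.arange(8, dtype=jnp.int32).reshape(1, 8)
print(nibbles[..., reverse_awq_order])   # [[0 4 1 5 2 6 3 7]]

# Group-wise dequantization: one (qzeros, scales) row per group of group_size
# input rows, matching weight = (qweight - qzeros) * scales in _apply_fused.
groups, group_size, out_features = 2, 4, 3
qweight = jnp.ones((groups, group_size, out_features), dtype=jnp.int8)
qzeros = jnp.zeros((groups, 1, out_features), dtype=jnp.int8)
scales = jnp.full((groups, 1, out_features), 0.5, dtype=jnp.bfloat16)

weight = (qweight - qzeros) * scales        # dequantize per group
weight = weight.reshape(-1, out_features)   # [in_features, out_features]
x = jnp.ones((1, weight.shape[0]), dtype=jnp.bfloat16)
print(jnp.einsum("bd,df->bf", x, weight))   # [[4. 4. 4.]]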
tpu_inference/layers/vllm/quantization/common.py
ADDED

@@ -0,0 +1,105 @@
+import torchax
+from jax.sharding import Mesh, PartitionSpec
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEConfig
+# yapf: disable
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               LinearBase,
+                                               MergedColumnParallelLinear,
+                                               QKVParallelLinear,
+                                               ReplicatedLinear,
+                                               RowParallelLinear)
+
+from tpu_inference.layers.vllm.linear_common import \
+    get_model_matmul_fusion_assignment
+from tpu_inference.utils import TPU_SECOND_LAST_MINOR
+
+# yapf: enable
+
+P = PartitionSpec
+
+logger = init_logger(__name__)
+
+
+class JaxCommonLinearConfig:
+
+    def __init__(self, vllm_config: VllmConfig, mesh: Mesh, layer: LinearBase):
+        assert isinstance(layer, LinearBase)
+
+        self.mesh = mesh
+        self.output_sizes = [layer.output_size]
+        self.weight_sharding = P(None, None)
+        self.fuse_matmuls = True
+        self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism
+        self.input_sharding = None
+        self.output_sharding = None
+
+        if isinstance(layer, RowParallelLinear):
+            self.weight_sharding = P(None, "model")
+            if self.enable_sequence_parallelism:
+                self.output_sharding = P("model", None)
+        elif isinstance(layer, ColumnParallelLinear):
+            self.weight_sharding = P("model", None)
+            if self.enable_sequence_parallelism:
+                self.input_sharding = P("model", None)
+
+            if isinstance(layer, MergedColumnParallelLinear) or isinstance(
+                    layer, QKVParallelLinear):
+                self.output_sizes = layer.output_sizes
+
+                self.fuse_matmuls = get_model_matmul_fusion_assignment(
+                    vllm_config.model_config.model,
+                    vllm_config.scheduler_config.max_num_batched_tokens,
+                    vllm_config.parallel_config.tensor_parallel_size,
+                    layer._get_name())
+        elif isinstance(layer, ReplicatedLinear):
+            self.weight_sharding = P(None, None)
+        else:
+            logger.warning(
+                "Unsupported linear layer type of %s. Can potentially yield "
+                " bad performance.", type(layer))
+
+        self.bias_sharding = P(self.weight_sharding[0])
+        self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)
+
+    def get_input_sharding(self, x: torchax.tensor.Tensor):
+        if self.enable_sequence_parallelism:
+            token_num = x.shape[0]
+            # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
+            if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
+                return self.input_sharding
+            else:
+                return None
+        return self.input_sharding
+
+    def get_output_sharding(self, x: torchax.tensor.Tensor):
+        if self.enable_sequence_parallelism:
+            token_num = x.shape[0]
+            # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
+            if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
+                return self.output_sharding
+            else:
+                return None
+        return self.output_sharding
+
+
+class JaxCommonConfig:
+    vllm_config: VllmConfig
+    mesh: Mesh
+
+    @classmethod
+    def set_configs(cls, vllm_config: VllmConfig, mesh: Mesh):
+        cls.vllm_config = vllm_config
+        cls.mesh = mesh
+
+    def get_linear_config(self, layer: LinearBase) -> JaxCommonLinearConfig:
+        assert isinstance(layer, LinearBase)
+        return JaxCommonLinearConfig(self.vllm_config, self.mesh, layer)
+
+    def get_moe_config(self, layer: FusedMoE) -> FusedMoEConfig:
+        assert isinstance(layer, FusedMoE)
+        moe_config = layer.moe_config
+        use_ep = self.vllm_config.parallel_config.enable_expert_parallel
+        moe_config.moe_parallel_config.use_ep = use_ep
+        return moe_config
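The PartitionSpec choices above (column-parallel weights sharded on the output dimension, row-parallel weights on the input dimension, bias following the first weight axis) can be illustrated with a small standalone JAX snippet. Only the "model" axis name is taken from the listing; the shapes and everything else are invented:

# Illustrative only: what P("model", None) vs P(None, "model") mean for a 2-D
# vLLM weight of shape [output_size, input_size] on a one-axis "model" mesh,
# mirroring JaxCommonLinearConfig. Toy shapes chosen to divide typical device counts.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("model",))
weight = jnp.zeros((8, 16))  # [output_size, input_size]

# Column-parallel: shard the output dimension across the "model" axis.
col = jax.device_put(weight, NamedSharding(mesh, P("model", None)))
# Row-parallel: shard the input dimension instead.
row = jax.device_put(weight, NamedSharding(mesh, P(None, "model")))
print(col.sharding.spec, row.sharding.spec)

# The bias follows the first weight axis, as in bias_sharding = P(weight_sharding[0]).
print(P(P("model", None)[0]))  # PartitionSpec('model',)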
tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py
File without changes
tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py
ADDED

@@ -0,0 +1,121 @@
+from typing import Optional
+
+import torch
+from jax.sharding import PartitionSpec
+from vllm.attention.layer import Attention
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+from vllm.model_executor.layers.linear import LinearBase
+from vllm.model_executor.layers.quantization import \
+    register_quantization_config
+from vllm.model_executor.layers.quantization.base_config import \
+    QuantizeMethodBase  # noqa: E501
+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
+    CompressedTensorsConfig, CompressedTensorsKVCacheMethod,
+    CompressedTensorsLinearMethod, CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+    find_matched_target, is_activation_quantization_format,
+    should_ignore_layer)
+
+from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
+from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
+    VllmCompressedTensorsW8A8Fp8
+from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
+    VllmCompressedTensorsW8A8Int8
+from tpu_inference.layers.vllm.quantization.unquantized import \
+    VllmUnquantizedConfig
+
+P = PartitionSpec
+logger = init_logger(__name__)
+
+
+@register_quantization_config("jax-compressed-tensors")
+class VllmCompressedTensorsConfig(CompressedTensorsConfig, JaxCommonConfig):
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "jax-compressed-tensors"
+
+    def get_scheme(self,
+                   layer: torch.nn.Module,
+                   layer_name: Optional[str] = None
+                   ) -> Optional["CompressedTensorsScheme"]:
+        """
+        compressed-tensors supports non uniform in the following way:
+
+        targets of config_groups: There can be N config_groups which each
+        have a quantization scheme. Each config_group has a list of targets
+        which can be a full layer_name, a regex for a layer_name, or
+        an nn.Module name.
+
+        Detect whether a layer_name is found in any target and
+        use the quantization scheme corresponding to the matched target
+        to select the CompressedTensorsScheme used for inference.
+        """
+
+        # Will be empty for models with only sparsity
+        weight_quant = input_quant = None
+        if self.target_scheme_map:
+            matched_target = find_matched_target(
+                layer_name=layer_name,
+                module=layer,
+                targets=self.target_scheme_map.keys(),
+                fused_mapping=self.packed_modules_mapping)
+
+            scheme_dict = self.target_scheme_map[matched_target]
+            weight_quant = scheme_dict.get("weights")
+            input_quant = scheme_dict.get("input_activations")
+            format = scheme_dict.get("format")
+
+        if weight_quant is None:
+            logger.warning_once("Acceleration for non-quantized schemes is "
+                                "not supported by Compressed Tensors. "
+                                "Falling back to UnquantizedLinearMethod")
+            return None
+
+        # TODO(kyuyeunk): Add support for different act_quant_format
+        act_quant_format = is_activation_quantization_format(  # noqa: F841
+            format
+        ) if format is not None else is_activation_quantization_format(
+            self.quant_format)
+
+        linear_config = self.get_linear_config(layer)
+        if self._is_fp8_w8a8(weight_quant, input_quant):
+            is_static_input_scheme = input_quant and not input_quant.dynamic
+            return VllmCompressedTensorsW8A8Fp8(
+                weight_quant=weight_quant,
+                is_static_input_scheme=is_static_input_scheme,
+                jax_config=linear_config,
+            )
+        if self._is_dynamic_token_w8a8(weight_quant, input_quant):
+            return VllmCompressedTensorsW8A8Int8(
+                strategy=weight_quant.strategy,
+                is_static_input_scheme=False,
+                input_symmetric=input_quant.symmetric,
+                jax_config=linear_config,
+            )
+        raise NotImplementedError(
+            "No compressed-tensors compatible scheme was found.")
+
+    def get_quant_method(
+        self,
+        layer: torch.nn.Module,
+        prefix: str,
+    ) -> Optional[QuantizeMethodBase]:
+        if should_ignore_layer(prefix,
+                               ignore=self.ignore,
+                               fused_mapping=self.packed_modules_mapping):
+            return VllmUnquantizedConfig.get_quant_method(self, layer, prefix)
+        if isinstance(layer, LinearBase):
+            scheme = self.get_scheme(layer=layer, layer_name=prefix)
+            if scheme is None:
+                return VllmUnquantizedConfig.get_quant_method(
+                    self, layer, prefix)
+            layer.scheme = scheme
+            return CompressedTensorsLinearMethod(self)
+        if isinstance(layer, FusedMoE):
+            raise NotImplementedError(
+                "FusedMoE quantization is currently not supported.")
+        if isinstance(layer, Attention):
+            return CompressedTensorsKVCacheMethod(self)
+        return None
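The target matching described in the get_scheme docstring (full layer names, "re:" regexes, or nn.Module names inside config_groups) roughly amounts to the toy matcher below. This is an illustration only: the real logic is vLLM's find_matched_target, and the config shape and helper here are invented:

# Illustrative only: a hypothetical, simplified stand-in for the matching the
# docstring describes; not the vLLM implementation.
import re

target_scheme_map = {
    "re:.*q_proj$": {"weights": "fp8-tensor", "input_activations": "fp8-static"},
    "Linear": {"weights": "int8-channel", "input_activations": "int8-dynamic"},
}

def match_target(layer_name: str, module_type: str) -> str:
    # 1) full names or "re:"-prefixed regexes, 2) fall back to the nn.Module name.
    for target in target_scheme_map:
        if target.startswith("re:") and re.match(target[3:], layer_name):
            return target
        if target == layer_name:
            return target
    if module_type in target_scheme_map:
        return module_type
    raise KeyError(f"no config_group target matches {layer_name}")

print(match_target("model.layers.0.self_attn.q_proj", "Linear"))  # "re:.*q_proj$"
print(match_target("model.layers.0.mlp.down_proj", "Linear"))     # "Linear"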
tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py
File without changes
tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
ADDED
@@ -0,0 +1,208 @@
+from typing import Optional
+
+import jax
+import jax.numpy as jnp
+import torch
+from compressed_tensors.quantization import (QuantizationArgs,
+                                             QuantizationStrategy)
+from jax.sharding import NamedSharding, PartitionSpec
+from torchax.interop import jax_view, torch_view
+from torchax.ops.mappings import t2j
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
+    CompressedTensorsW8A8Fp8
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import \
+    per_tensor_dequantize
+
+from tpu_inference.layers.vllm.linear_common import (
+    sharded_quantized_matmul, slice_sharded_tensor_for_concatenation,
+    torch_to_jax_param)
+from tpu_inference.layers.vllm.quantization.common import JaxCommonLinearConfig
+
+P = PartitionSpec
+
+
+def requantize_with_max_scale(
+        weight: torch.Tensor, weight_scale: torch.Tensor,
+        logical_widths: list[int]) -> tuple[torch.Tensor, torch.Tensor]:
+    dtype = weight.dtype
+    dtype_info = torch.finfo(dtype)
+    maxval = float(dtype_info.max)
+    minval = float(dtype_info.min)
+
+    max_w_scale = weight_scale.max()
+
+    unfused_module_in_checkpoint = (weight_scale[-1]
+                                    > torch.finfo(torch.float8_e4m3fn).min)
+
+    # If unfused checkpoint, need to requantize with the single scale.
+    if unfused_module_in_checkpoint:
+        start = 0
+        for idx, logical_width in enumerate(logical_widths):
+            # Skip any component with zero width.
+            if logical_width == 0:
+                continue
+            end = start + logical_width
+            weight_dq = per_tensor_dequantize(weight[start:end, :],
+                                              weight_scale[idx])
+            weight_q = weight_dq / max_w_scale
+            weight[start:end, :] = weight_q.clamp(minval, maxval).to(dtype)
+            start = end
+
+    return max_w_scale, weight
+
+
+class VllmCompressedTensorsW8A8Fp8(CompressedTensorsW8A8Fp8):
+
+    def __init__(
+        self,
+        weight_quant: QuantizationArgs,
+        is_static_input_scheme: bool,
+        jax_config: JaxCommonLinearConfig,
+    ):
+        super().__init__(weight_quant, is_static_input_scheme)
+
+        self.jax_config = jax_config
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        weight = layer.weight
+        weight_scale = layer.weight_scale
+
+        if self.is_static_input_scheme:
+            # In static quant, all input_scales share the same value.
+            assert layer.input_scale.min() == layer.input_scale.max()
+            input_scale_first = layer.input_scale[0]
+
+            input_scale = jax.device_put(
+                t2j(input_scale_first, use_dlpack=False),
+                NamedSharding(self.jax_config.mesh, P()))
+            input_scale = torch.nn.Parameter(torch_view(input_scale),
+                                             requires_grad=False)
+            delattr(layer, "input_scale")
+            layer.input_scale = input_scale
+
+            # TODO(kyuyeunk): Investigate performance gain from merging scales.
+            # By merging input and weight scales, we reduce the number of muls
+            # required for dequantization from 2 (one for each scale) to 1.
+            # weight_scale *= input_scale_first
+
+        if self.strategy == QuantizationStrategy.TENSOR:
+            weight_scale, weight = requantize_with_max_scale(
+                weight, weight_scale, self.jax_config.output_sizes)
+            weight_scale = jax.device_put(
+                t2j(weight_scale, use_dlpack=False),
+                NamedSharding(self.jax_config.mesh, P()))
+            weight_scale = torch.nn.Parameter(torch_view(weight_scale),
+                                              requires_grad=False)
+        else:
+            weight_scale = weight_scale.squeeze(-1)
+            weight_scale = torch_to_jax_param(
+                weight_scale,
+                NamedSharding(self.jax_config.mesh,
+                              self.jax_config.bias_sharding),
+                self.jax_config.output_sizes, self.jax_config.n_shards,
+                self.jax_config.fuse_matmuls)
+        delattr(layer, "weight_scale")
+        layer.weight_scale = weight_scale
+
+        weight = torch_to_jax_param(
+            layer.weight,
+            NamedSharding(self.jax_config.mesh,
+                          self.jax_config.weight_sharding),
+            self.jax_config.output_sizes, self.jax_config.n_shards,
+            self.jax_config.fuse_matmuls)
+        delattr(layer, "weight")
+        layer.weight = weight
+
+        if layer.bias is not None:
+            bias = torch_to_jax_param(
+                layer.bias,
+                NamedSharding(self.jax_config.mesh,
+                              self.jax_config.bias_sharding),
+                self.jax_config.output_sizes, self.jax_config.n_shards,
+                self.jax_config.fuse_matmuls)
+            delattr(layer, "bias")
+            layer.bias = bias
+
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]) -> torch.Tensor:
+        with jax.named_scope(layer._get_name()):
+            if self.jax_config.fuse_matmuls:
+                return self._apply_fused(layer, x, bias)
+            else:
+                return self._apply_split(layer, x, bias)
+
+    def _apply_fused(self, layer: torch.nn.Module, x: torch.Tensor,
+                     bias: Optional[torch.Tensor]) -> torch.Tensor:
+        x_jax = jax_view(x)
+        weight_jax = jax_view(layer.weight)
+        weight_scale_jax = jax_view(layer.weight_scale)
+
+        if self.is_static_input_scheme:
+            # TODO(kyuyeunk): Add kernel support for static quant
+            input_scale = jax_view(layer.input_scale)
+            dtype_info = jnp.finfo(weight_jax.dtype)
+            maxval = float(dtype_info.max)
+            minval = float(dtype_info.min)
+            x_q = jnp.clip(x_jax / input_scale.astype(x_jax.dtype), minval,
+                           maxval).astype(weight_jax.dtype)
+
+            outs = jax.lax.dot_general(
+                x_q,
+                weight_jax,
+                (((1, ), (1, )), ((), ())),
+                preferred_element_type=jnp.float32,
+            )
+            outs *= weight_scale_jax
+            outs = outs.astype(x_jax.dtype)
+        else:
+            outs = sharded_quantized_matmul(x_jax, weight_jax,
+                                            weight_scale_jax,
+                                            self.jax_config.mesh,
+                                            self.jax_config.weight_sharding)
+
+        if bias is not None and not layer.skip_bias_add:
+            outs += jax_view(bias)
+        outs = slice_sharded_tensor_for_concatenation(
+            outs, self.jax_config.output_sizes, self.jax_config.n_shards)
+        return torch_view(jnp.concatenate(outs, axis=-1))
+
+    def _apply_split(self, layer: torch.nn.Module, x: torch.Tensor,
+                     bias: Optional[torch.Tensor]) -> torch.Tensor:
+        assert isinstance(layer.weight, torch.nn.ParameterList)
+
+        x_jax = jax_view(x)
+        outs = []
+        for i, (weight, weight_scale) in enumerate(
+                zip(layer.weight, layer.weight_scale)):
+            weight_jax = jax_view(weight)
+            weight_scale_jax = jax_view(weight_scale)
+
+            if self.is_static_input_scheme:
+                # TODO(kyuyeunk): Add kernel support for static quant
+                input_scale = jax_view(layer.input_scale)
+                dtype_info = jnp.finfo(weight_jax.dtype)
+                maxval = float(dtype_info.max)
+                minval = float(dtype_info.min)
+                x_q = jnp.clip(x_jax / input_scale.astype(x_jax.dtype), minval,
+                               maxval).astype(weight_jax.dtype)
+
+                out = jax.lax.dot_general(
+                    x_q,
+                    weight_jax,
+                    (((1, ), (1, )), ((), ())),
+                    preferred_element_type=jnp.float32,
+                )
+                # TODO(kyuyeunk): Investigate performance gain from merging scales.
+                # out *= weight_scale_jax
+                out *= weight_scale_jax * input_scale
+                out = out.astype(x_jax.dtype)
+            else:
+                out = sharded_quantized_matmul(x_jax, weight_jax,
+                                               weight_scale_jax,
+                                               self.jax_config.mesh,
+                                               self.jax_config.weight_sharding)
+
+            if bias is not None and not layer.skip_bias_add:
+                out += jax_view(bias[i])
+            outs.append(out)
+        return torch_view(jnp.concatenate(outs, axis=-1))
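The static W8A8 FP8 arithmetic used in _apply_split can be checked numerically with a short standalone sketch. Shapes, values, and scales are invented; the fp8 operands are upcast explicitly here for portability, whereas the listing feeds them to dot_general directly with preferred_element_type=jnp.float32:

# Illustrative only: quantize activations with a fixed input_scale, matmul with
# float32 accumulation, then dequantize with weight_scale * input_scale.
import jax
import jax.numpy as jnp

f8 = jnp.float8_e4m3fn
finfo = jnp.finfo(f8)

x = jnp.array([[0.5, -1.0, 2.0]], dtype=jnp.bfloat16)   # [tokens, in_features]
w_fp8 = jnp.array([[1.0, -0.5, 0.25]], dtype=f8)        # [out_features, in_features]
weight_scale = jnp.float32(0.02)                        # per-tensor weight scale
input_scale = jnp.float32(0.1)                          # static per-tensor activation scale

# Quantize activations into the fp8 representable range.
x_q = jnp.clip(x / input_scale.astype(x.dtype),
               float(finfo.min), float(finfo.max)).astype(f8)

# Contract the shared in_features dimension, accumulating in float32
# (dimension_numbers (((1,), (1,)), ((), ())) as in the listing).
out = jax.lax.dot_general(x_q.astype(jnp.float32), w_fp8.astype(jnp.float32),
                          (((1, ), (1, )), ((), ())),
                          preferred_element_type=jnp.float32)

# Dequantize and compare with the unquantized reference; both print ~0.03.
out = (out * weight_scale * input_scale).astype(x.dtype)
ref = x @ (w_fp8.astype(x.dtype).T * weight_scale.astype(x.dtype))
print(out, ref)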