tpu-inference 0.11.1rc2__py3-none-any.whl → 0.11.1rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of tpu-inference has been flagged as potentially problematic.

Files changed (50)
  1. tpu_inference/kernels/collectives/__init__.py +0 -0
  2. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  3. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  4. tpu_inference/kernels/collectives/util.py +47 -0
  5. tpu_inference/layers/__init__.py +0 -0
  6. tpu_inference/layers/common/__init__.py +0 -0
  7. tpu_inference/layers/common/attention_metadata.py +34 -0
  8. tpu_inference/layers/jax/__init__.py +0 -0
  9. tpu_inference/layers/jax/attention/__init__.py +0 -0
  10. tpu_inference/layers/jax/attention/attention.py +254 -0
  11. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  12. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  13. tpu_inference/layers/jax/attention_interface.py +356 -0
  14. tpu_inference/layers/jax/base.py +151 -0
  15. tpu_inference/layers/jax/binary_search.py +295 -0
  16. tpu_inference/layers/jax/constants.py +88 -0
  17. tpu_inference/layers/jax/layers.py +301 -0
  18. tpu_inference/layers/jax/misc.py +16 -0
  19. tpu_inference/layers/jax/moe/__init__.py +0 -0
  20. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  21. tpu_inference/layers/jax/moe/moe.py +209 -0
  22. tpu_inference/layers/jax/rope.py +172 -0
  23. tpu_inference/layers/jax/rope_interface.py +214 -0
  24. tpu_inference/layers/jax/sample/__init__.py +0 -0
  25. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  26. tpu_inference/layers/jax/sample/sampling.py +95 -0
  27. tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
  28. tpu_inference/layers/jax/sharding.py +406 -0
  29. tpu_inference/layers/jax/transformer_block.py +76 -0
  30. tpu_inference/layers/vllm/__init__.py +0 -0
  31. tpu_inference/layers/vllm/attention.py +184 -0
  32. tpu_inference/layers/vllm/fused_moe.py +399 -0
  33. tpu_inference/layers/vllm/linear_common.py +186 -0
  34. tpu_inference/layers/vllm/quantization/__init__.py +34 -0
  35. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  36. tpu_inference/layers/vllm/quantization/common.py +105 -0
  37. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  38. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
  39. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  40. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  41. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  42. tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
  43. tpu_inference/layers/vllm/sharding.py +151 -0
  44. tpu_inference/models/common/__init__.py +0 -0
  45. tpu_inference/models/common/model_loader.py +433 -0
  46. {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/METADATA +6 -6
  47. {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/RECORD +50 -5
  48. {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/WHEEL +1 -1
  49. {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/licenses/LICENSE +0 -0
  50. {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/top_level.txt +0 -0
tpu_inference/layers/vllm/quantization/awq.py
@@ -0,0 +1,207 @@
+ from typing import Optional, Union
+
+ import jax
+ import jax.numpy as jnp
+ import torch
+ from jax.sharding import NamedSharding, PartitionSpec
+ from torchax.interop import jax_view, torch_view
+ from vllm.logger import init_logger
+ from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+ from vllm.model_executor.layers.quantization import \
+     register_quantization_config
+ from vllm.model_executor.layers.quantization.awq import (AWQConfig,
+                                                          AWQLinearMethod,
+                                                          is_layer_skipped_awq)
+ from vllm.model_executor.layers.quantization.base_config import \
+     QuantizeMethodBase
+ from vllm.model_executor.layers.quantization.utils.quant_utils import \
+     unpack_quantized_values_into_int32
+ from vllm.scalar_type import scalar_types
+
+ from tpu_inference.layers.vllm.linear_common import (
+     slice_sharded_tensor_for_concatenation, torch_to_jax_param)
+ from tpu_inference.layers.vllm.quantization.common import (
+     JaxCommonConfig, JaxCommonLinearConfig)
+ from tpu_inference.layers.vllm.quantization.unquantized import \
+     VllmUnquantizedLinearMethod
+
+ P = PartitionSpec
+ logger = init_logger(__name__)
+
+
+ @register_quantization_config("jax-awq")
+ class VllmAWQConfig(AWQConfig, JaxCommonConfig):
+
+     @classmethod
+     def get_name(cls) -> str:
+         return "jax-awq"
+
+     def get_supported_act_dtypes(self) -> list[torch.dtype]:
+         # NOTE: AWQ checkpoints are quantized with float16. But on TPUs, using
+         # bfloat16 is significantly preferred over float16. This might lead to
+         # some numeric output change.
+         return [torch.bfloat16]
+
+     def get_quant_method(
+         self, layer: torch.nn.Module, prefix: str
+     ) -> Optional[Union["LinearMethodBase", "QuantizeMethodBase"]]:
+         if isinstance(layer, LinearBase):
+             linear_config = self.get_linear_config(layer)
+             if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+                 return VllmUnquantizedLinearMethod(linear_config)
+             return VllmAWQLinearMethod(self, linear_config)
+         elif isinstance(layer, FusedMoE):
+             raise NotImplementedError(
+                 "AWQ FusedMoE is currently not supported in torchax-jax")
+         return None
+
+
+ class VllmAWQLinearMethod(AWQLinearMethod):
+
+     def __init__(self, quant_config: VllmAWQConfig,
+                  jax_config: JaxCommonLinearConfig):
+         super().__init__(quant_config)
+         self.jax_config = jax_config
+
+         out_sharding, in_sharding = self.jax_config.weight_sharding[:]
+         self.jax_config.weight_sharding = P(in_sharding, None, out_sharding)
+         self.jax_config.scale_sharding = P(in_sharding, out_sharding)
+
+     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+         qweight = layer.qweight
+         qweight = unpack_awq_weight(qweight, qweight.packed_dim)
+
+         group_size = self.quant_config.group_size
+         # Reshape so that each qweight[i] was quantized with the same scales[i].
+         qweight = qweight.reshape((-1, group_size, layer.output_size))
+         qweight = torch_to_jax_param(qweight,
+                                      NamedSharding(
+                                          self.jax_config.mesh,
+                                          self.jax_config.weight_sharding),
+                                      self.jax_config.output_sizes,
+                                      self.jax_config.n_shards,
+                                      self.jax_config.fuse_matmuls,
+                                      dim=2,
+                                      jax_dtype=jnp.uint4)
+         delattr(layer, "qweight")
+         layer.qweight = qweight
+
+         qzeros = layer.qzeros
+         qzeros = unpack_awq_weight(qzeros, qzeros.packed_dim)
+         qzeros = torch_to_jax_param(qzeros,
+                                     NamedSharding(
+                                         self.jax_config.mesh,
+                                         self.jax_config.scale_sharding),
+                                     self.jax_config.output_sizes,
+                                     self.jax_config.n_shards,
+                                     self.jax_config.fuse_matmuls,
+                                     dim=1,
+                                     jax_dtype=jnp.uint4)
+         delattr(layer, "qzeros")
+         layer.qzeros = qzeros
+
+         scales = torch_to_jax_param(layer.scales,
+                                     NamedSharding(
+                                         self.jax_config.mesh,
+                                         self.jax_config.scale_sharding),
+                                     self.jax_config.output_sizes,
+                                     self.jax_config.n_shards,
+                                     self.jax_config.fuse_matmuls,
+                                     dim=1)
+         delattr(layer, "scales")
+         layer.scales = scales
+
+         if layer.bias is not None and not layer.skip_bias_add:
+             if layer.return_bias:
+                 logger.warning_once("Bias might return incorrect value.")
+
+             bias = torch_to_jax_param(
+                 layer.bias,
+                 NamedSharding(self.jax_config.mesh,
+                               self.jax_config.bias_sharding),
+                 self.jax_config.output_sizes,
+                 self.jax_config.n_shards,
+                 self.jax_config.fuse_matmuls,
+             )
+             delattr(layer, "bias")
+             layer.bias = bias
+
+     def apply(self,
+               layer: torch.nn.Module,
+               x: torch.Tensor,
+               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+         with jax.named_scope(layer._get_name()):
+             if self.jax_config.fuse_matmuls:
+                 out = self._apply_fused(layer, x, bias)
+             else:
+                 out = self._apply_split(layer, x, bias)
+
+         return out
+
+     def _apply_fused(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+         x_jax = jax_view(x)
+
+         qweight = jax_view(layer.qweight)
+         qzeros = jnp.expand_dims(jax_view(layer.qzeros), 1)
+         scales = jnp.expand_dims(jax_view(layer.scales), 1)
+
+         qweight = qweight.astype(jnp.int8)
+         qzeros = qzeros.astype(jnp.int8)
+
+         weight = (qweight - qzeros) * scales
+         weight = weight.reshape((-1, weight.shape[-1]))
+         outs = jnp.einsum("bd,df->bf", x_jax, weight)
+
+         if bias is not None and not layer.skip_bias_add:
+             outs += bias.jax()
+
+         outs = slice_sharded_tensor_for_concatenation(
+             outs, self.jax_config.output_sizes, self.jax_config.n_shards)
+         out = jnp.concatenate(outs, axis=-1)
+         return torch_view(out)
+
+     def _apply_split(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+         assert isinstance(layer.qweight, torch.nn.ParameterList)
+
+         x_jax = jax_view(x)
+         params = zip(layer.qweight, layer.qzeros, layer.scales)
+         outs = []
+         for i, (qweight, qzeros, scales) in enumerate(params):
+             qweight = jax_view(qweight)
+             scales = jnp.expand_dims(jax_view(scales), 1)
+             qzeros = jnp.expand_dims(jax_view(qzeros), 1)
+
+             qweight = qweight.astype(jnp.int8)
+             qzeros = qzeros.astype(jnp.int8)
+
+             weight = (qweight - qzeros) * scales
+             weight = weight.reshape((-1, weight.shape[-1]))
+             out = jnp.einsum("bd,df->bf", x_jax, weight)
+
+             if bias is not None and not layer.skip_bias_add:
+                 out += jax_view(bias[i])
+
+             outs.append(out)
+         out = jnp.concatenate(outs, axis=-1)
+         return torch_view(out)
+
+
+ def unpack_awq_weight(weight: torch.Tensor, packed_dim: int):
+     weight = unpack_quantized_values_into_int32(weight, scalar_types.uint4,
+                                                 packed_dim)
+
+     # AWQ packs 8 uint4 into 32-bits in this order: (0, 2, 4, 6, 1, 3, 5, 7).
+     # The following list maps the order used by AWQ back into ascending order.
+     reverse_awq_order = (0, 4, 1, 5, 2, 6, 3, 7)
+
+     orig_shape = weight.shape
+     weight = weight.reshape(orig_shape[:-1] + (-1, 8))
+     return weight[..., reverse_awq_order].reshape(orig_shape)
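
The reordering in unpack_awq_weight can be sanity-checked with a few lines of plain Python. The sketch below is illustrative only (it is not part of the package) and assumes, as the comment above implies, that unpack_quantized_values_into_int32 yields nibbles from the least-significant position upward.

# Illustrative sketch, not package code: pack eight 4-bit values into one
# 32-bit word in AWQ's interleaved nibble order, then show that indexing the
# nibble-order unpack with (0, 4, 1, 5, 2, 6, 3, 7) restores ascending order.
values = list(range(8))                     # v0..v7, each fits in 4 bits
awq_pack_order = (0, 2, 4, 6, 1, 3, 5, 7)   # nibble i holds values[awq_pack_order[i]]

packed = 0
for nibble, src in enumerate(awq_pack_order):
    packed |= values[src] << (4 * nibble)

# Unpack nibbles least-significant first (the assumed helper behavior).
unpacked = [(packed >> (4 * i)) & 0xF for i in range(8)]
print(unpacked)                                    # [0, 2, 4, 6, 1, 3, 5, 7]

reverse_awq_order = (0, 4, 1, 5, 2, 6, 3, 7)
print([unpacked[i] for i in reverse_awq_order])    # [0, 1, 2, 3, 4, 5, 6, 7]
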
tpu_inference/layers/vllm/quantization/common.py
@@ -0,0 +1,105 @@
+ import torchax
+ from jax.sharding import Mesh, PartitionSpec
+ from vllm.config import VllmConfig
+ from vllm.logger import init_logger
+ from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEConfig
+ # yapf: disable
+ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                                LinearBase,
+                                                MergedColumnParallelLinear,
+                                                QKVParallelLinear,
+                                                ReplicatedLinear,
+                                                RowParallelLinear)
+
+ from tpu_inference.layers.vllm.linear_common import \
+     get_model_matmul_fusion_assignment
+ from tpu_inference.utils import TPU_SECOND_LAST_MINOR
+
+ # yapf: enable
+
+ P = PartitionSpec
+
+ logger = init_logger(__name__)
+
+
+ class JaxCommonLinearConfig:
+
+     def __init__(self, vllm_config: VllmConfig, mesh: Mesh, layer: LinearBase):
+         assert isinstance(layer, LinearBase)
+
+         self.mesh = mesh
+         self.output_sizes = [layer.output_size]
+         self.weight_sharding = P(None, None)
+         self.fuse_matmuls = True
+         self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism
+         self.input_sharding = None
+         self.output_sharding = None
+
+         if isinstance(layer, RowParallelLinear):
+             self.weight_sharding = P(None, "model")
+             if self.enable_sequence_parallelism:
+                 self.output_sharding = P("model", None)
+         elif isinstance(layer, ColumnParallelLinear):
+             self.weight_sharding = P("model", None)
+             if self.enable_sequence_parallelism:
+                 self.input_sharding = P("model", None)
+
+             if isinstance(layer, MergedColumnParallelLinear) or isinstance(
+                     layer, QKVParallelLinear):
+                 self.output_sizes = layer.output_sizes
+
+                 self.fuse_matmuls = get_model_matmul_fusion_assignment(
+                     vllm_config.model_config.model,
+                     vllm_config.scheduler_config.max_num_batched_tokens,
+                     vllm_config.parallel_config.tensor_parallel_size,
+                     layer._get_name())
+         elif isinstance(layer, ReplicatedLinear):
+             self.weight_sharding = P(None, None)
+         else:
+             logger.warning(
+                 "Unsupported linear layer type of %s. Can potentially yield "
+                 "bad performance.", type(layer))
+
+         self.bias_sharding = P(self.weight_sharding[0])
+         self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)
+
+     def get_input_sharding(self, x: torchax.tensor.Tensor):
+         if self.enable_sequence_parallelism:
+             token_num = x.shape[0]
+             # NOTE(chengjiyao): make sure the sharded token_num is at least TPU_SECOND_LAST_MINOR
+             if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
+                 return self.input_sharding
+             else:
+                 return None
+         return self.input_sharding
+
+     def get_output_sharding(self, x: torchax.tensor.Tensor):
+         if self.enable_sequence_parallelism:
+             token_num = x.shape[0]
+             # NOTE(chengjiyao): make sure the sharded token_num is at least TPU_SECOND_LAST_MINOR
+             if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
+                 return self.output_sharding
+             else:
+                 return None
+         return self.output_sharding
+
+
+ class JaxCommonConfig:
+     vllm_config: VllmConfig
+     mesh: Mesh
+
+     @classmethod
+     def set_configs(cls, vllm_config: VllmConfig, mesh: Mesh):
+         cls.vllm_config = vllm_config
+         cls.mesh = mesh
+
+     def get_linear_config(self, layer: LinearBase) -> JaxCommonLinearConfig:
+         assert isinstance(layer, LinearBase)
+         return JaxCommonLinearConfig(self.vllm_config, self.mesh, layer)
+
+     def get_moe_config(self, layer: FusedMoE) -> FusedMoEConfig:
+         assert isinstance(layer, FusedMoE)
+         moe_config = layer.moe_config
+         use_ep = self.vllm_config.parallel_config.enable_expert_parallel
+         moe_config.moe_parallel_config.use_ep = use_ep
+         return moe_config
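
For readers less familiar with JAX sharding, the standalone sketch below (not package code) shows how the PartitionSpecs chosen above translate into an output-shard count through mesh.shape, assuming a one-dimensional device mesh whose axis is named "model", the axis name JaxCommonLinearConfig relies on. On a single-device host it simply reports one shard everywhere.

import jax
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

# Build a 1-D mesh with a "model" axis over whatever devices are available.
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

column_parallel_weight = P("model", None)   # output dim sharded over "model"
row_parallel_weight = P(None, "model")      # input dim sharded over "model"
replicated_weight = P(None, None)           # fully replicated

for spec in (column_parallel_weight, row_parallel_weight, replicated_weight):
    # Mirrors `n_shards = mesh.shape.get(weight_sharding[0], 1)` above: only
    # sharding on the first (output) dimension contributes output shards.
    n_shards = mesh.shape.get(spec[0], 1)
    print(spec, "->", n_shards, "output shard(s)")
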
tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py
@@ -0,0 +1,121 @@
+ from typing import Optional
+
+ import torch
+ from jax.sharding import PartitionSpec
+ from vllm.attention.layer import Attention
+ from vllm.logger import init_logger
+ from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+ from vllm.model_executor.layers.linear import LinearBase
+ from vllm.model_executor.layers.quantization import \
+     register_quantization_config
+ from vllm.model_executor.layers.quantization.base_config import \
+     QuantizeMethodBase  # noqa: E501
+ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
+     CompressedTensorsConfig, CompressedTensorsKVCacheMethod,
+     CompressedTensorsLinearMethod, CompressedTensorsScheme)
+ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+     find_matched_target, is_activation_quantization_format,
+     should_ignore_layer)
+
+ from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
+ from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
+     VllmCompressedTensorsW8A8Fp8
+ from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
+     VllmCompressedTensorsW8A8Int8
+ from tpu_inference.layers.vllm.quantization.unquantized import \
+     VllmUnquantizedConfig
+
+ P = PartitionSpec
+ logger = init_logger(__name__)
+
+
+ @register_quantization_config("jax-compressed-tensors")
+ class VllmCompressedTensorsConfig(CompressedTensorsConfig, JaxCommonConfig):
+
+     @classmethod
+     def get_name(cls) -> str:
+         return "jax-compressed-tensors"
+
+     def get_scheme(self,
+                    layer: torch.nn.Module,
+                    layer_name: Optional[str] = None
+                    ) -> Optional["CompressedTensorsScheme"]:
+         """
+         compressed-tensors supports non-uniform quantization in the following way:
+
+         targets of config_groups: There can be N config_groups, each of which
+         has a quantization scheme. Each config_group has a list of targets
+         which can be a full layer_name, a regex for a layer_name, or
+         an nn.Module name.
+
+         Detect whether a layer_name is found in any target and
+         use the quantization scheme corresponding to the matched target
+         to select the CompressedTensorsScheme used for inference.
+         """
+
+         # Will be empty for models with only sparsity
+         weight_quant = input_quant = None
+         if self.target_scheme_map:
+             matched_target = find_matched_target(
+                 layer_name=layer_name,
+                 module=layer,
+                 targets=self.target_scheme_map.keys(),
+                 fused_mapping=self.packed_modules_mapping)
+
+             scheme_dict = self.target_scheme_map[matched_target]
+             weight_quant = scheme_dict.get("weights")
+             input_quant = scheme_dict.get("input_activations")
+             format = scheme_dict.get("format")
+
+         if weight_quant is None:
+             logger.warning_once("Acceleration for non-quantized schemes is "
+                                 "not supported by Compressed Tensors. "
+                                 "Falling back to UnquantizedLinearMethod")
+             return None
+
+         # TODO(kyuyeunk): Add support for different act_quant_format
+         act_quant_format = is_activation_quantization_format(  # noqa: F841
+             format
+         ) if format is not None else is_activation_quantization_format(
+             self.quant_format)
+
+         linear_config = self.get_linear_config(layer)
+         if self._is_fp8_w8a8(weight_quant, input_quant):
+             is_static_input_scheme = input_quant and not input_quant.dynamic
+             return VllmCompressedTensorsW8A8Fp8(
+                 weight_quant=weight_quant,
+                 is_static_input_scheme=is_static_input_scheme,
+                 jax_config=linear_config,
+             )
+         if self._is_dynamic_token_w8a8(weight_quant, input_quant):
+             return VllmCompressedTensorsW8A8Int8(
+                 strategy=weight_quant.strategy,
+                 is_static_input_scheme=False,
+                 input_symmetric=input_quant.symmetric,
+                 jax_config=linear_config,
+             )
+         raise NotImplementedError(
+             "No compressed-tensors compatible scheme was found.")
+
+     def get_quant_method(
+         self,
+         layer: torch.nn.Module,
+         prefix: str,
+     ) -> Optional[QuantizeMethodBase]:
+         if should_ignore_layer(prefix,
+                                ignore=self.ignore,
+                                fused_mapping=self.packed_modules_mapping):
+             return VllmUnquantizedConfig.get_quant_method(self, layer, prefix)
+         if isinstance(layer, LinearBase):
+             scheme = self.get_scheme(layer=layer, layer_name=prefix)
+             if scheme is None:
+                 return VllmUnquantizedConfig.get_quant_method(
+                     self, layer, prefix)
+             layer.scheme = scheme
+             return CompressedTensorsLinearMethod(self)
+         if isinstance(layer, FusedMoE):
+             raise NotImplementedError(
+                 "FusedMoE quantization is currently not supported.")
+         if isinstance(layer, Attention):
+             return CompressedTensorsKVCacheMethod(self)
+         return None
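
The docstring in get_scheme above describes how a layer is mapped to a config_group target. As a rough illustration only, the simplified matcher below mimics the idea behind vLLM's find_matched_target (targets may be exact layer names, "re:"-prefixed regexes, or module class names); the match_target helper and the target strings are hypothetical and ignore the priority rules of the real implementation.

import re
from typing import Optional

def match_target(layer_name: str, module_cls: str,
                 targets: list[str]) -> Optional[str]:
    # Return the first target that matches the layer, None otherwise.
    for target in targets:
        if target.startswith("re:") and re.search(target[3:], layer_name):
            return target
        if target in (layer_name, module_cls):
            return target
    return None

targets = ["model.layers.0.mlp.down_proj", "re:.*q_proj$", "Linear"]
print(match_target("model.layers.3.self_attn.q_proj", "Linear", targets))
# -> "re:.*q_proj$"
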
tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@@ -0,0 +1,208 @@
+ from typing import Optional
+
+ import jax
+ import jax.numpy as jnp
+ import torch
+ from compressed_tensors.quantization import (QuantizationArgs,
+                                              QuantizationStrategy)
+ from jax.sharding import NamedSharding, PartitionSpec
+ from torchax.interop import jax_view, torch_view
+ from torchax.ops.mappings import t2j
+ from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
+     CompressedTensorsW8A8Fp8
+ from vllm.model_executor.layers.quantization.utils.w8a8_utils import \
+     per_tensor_dequantize
+
+ from tpu_inference.layers.vllm.linear_common import (
+     sharded_quantized_matmul, slice_sharded_tensor_for_concatenation,
+     torch_to_jax_param)
+ from tpu_inference.layers.vllm.quantization.common import JaxCommonLinearConfig
+
+ P = PartitionSpec
+
+
+ def requantize_with_max_scale(
+         weight: torch.Tensor, weight_scale: torch.Tensor,
+         logical_widths: list[int]) -> tuple[torch.Tensor, torch.Tensor]:
+     dtype = weight.dtype
+     dtype_info = torch.finfo(dtype)
+     maxval = float(dtype_info.max)
+     minval = float(dtype_info.min)
+
+     max_w_scale = weight_scale.max()
+
+     unfused_module_in_checkpoint = (weight_scale[-1]
+                                     > torch.finfo(torch.float8_e4m3fn).min)
+
+     # If the checkpoint is unfused, requantize with the single max scale.
+     if unfused_module_in_checkpoint:
+         start = 0
+         for idx, logical_width in enumerate(logical_widths):
+             # Skip any component with zero width.
+             if logical_width == 0:
+                 continue
+             end = start + logical_width
+             weight_dq = per_tensor_dequantize(weight[start:end, :],
+                                               weight_scale[idx])
+             weight_q = weight_dq / max_w_scale
+             weight[start:end, :] = weight_q.clamp(minval, maxval).to(dtype)
+             start = end
+
+     return max_w_scale, weight
+
+
+ class VllmCompressedTensorsW8A8Fp8(CompressedTensorsW8A8Fp8):
+
+     def __init__(
+         self,
+         weight_quant: QuantizationArgs,
+         is_static_input_scheme: bool,
+         jax_config: JaxCommonLinearConfig,
+     ):
+         super().__init__(weight_quant, is_static_input_scheme)
+
+         self.jax_config = jax_config
+
+     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+         weight = layer.weight
+         weight_scale = layer.weight_scale
+
+         if self.is_static_input_scheme:
+             # In static quant, all input_scales share the same value.
+             assert layer.input_scale.min() == layer.input_scale.max()
+             input_scale_first = layer.input_scale[0]
+
+             input_scale = jax.device_put(
+                 t2j(input_scale_first, use_dlpack=False),
+                 NamedSharding(self.jax_config.mesh, P()))
+             input_scale = torch.nn.Parameter(torch_view(input_scale),
+                                              requires_grad=False)
+             delattr(layer, "input_scale")
+             layer.input_scale = input_scale
+
+             # TODO(kyuyeunk): Investigate performance gain from merging scales.
+             # By merging input and weight scales, we reduce the number of muls
+             # required for dequantization from 2 (one for each scale) to 1.
+             # weight_scale *= input_scale_first
+
+         if self.strategy == QuantizationStrategy.TENSOR:
+             weight_scale, weight = requantize_with_max_scale(
+                 weight, weight_scale, self.jax_config.output_sizes)
+             weight_scale = jax.device_put(
+                 t2j(weight_scale, use_dlpack=False),
+                 NamedSharding(self.jax_config.mesh, P()))
+             weight_scale = torch.nn.Parameter(torch_view(weight_scale),
+                                               requires_grad=False)
+         else:
+             weight_scale = weight_scale.squeeze(-1)
+             weight_scale = torch_to_jax_param(
+                 weight_scale,
+                 NamedSharding(self.jax_config.mesh,
+                               self.jax_config.bias_sharding),
+                 self.jax_config.output_sizes, self.jax_config.n_shards,
+                 self.jax_config.fuse_matmuls)
+         delattr(layer, "weight_scale")
+         layer.weight_scale = weight_scale
+
+         weight = torch_to_jax_param(
+             layer.weight,
+             NamedSharding(self.jax_config.mesh,
+                           self.jax_config.weight_sharding),
+             self.jax_config.output_sizes, self.jax_config.n_shards,
+             self.jax_config.fuse_matmuls)
+         delattr(layer, "weight")
+         layer.weight = weight
+
+         if layer.bias is not None:
+             bias = torch_to_jax_param(
+                 layer.bias,
+                 NamedSharding(self.jax_config.mesh,
+                               self.jax_config.bias_sharding),
+                 self.jax_config.output_sizes, self.jax_config.n_shards,
+                 self.jax_config.fuse_matmuls)
+             delattr(layer, "bias")
+             layer.bias = bias
+
+     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+                       bias: Optional[torch.Tensor]) -> torch.Tensor:
+         with jax.named_scope(layer._get_name()):
+             if self.jax_config.fuse_matmuls:
+                 return self._apply_fused(layer, x, bias)
+             else:
+                 return self._apply_split(layer, x, bias)
+
+     def _apply_fused(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]) -> torch.Tensor:
+         x_jax = jax_view(x)
+         weight_jax = jax_view(layer.weight)
+         weight_scale_jax = jax_view(layer.weight_scale)
+
+         if self.is_static_input_scheme:
+             # TODO(kyuyeunk): Add kernel support for static quant
+             input_scale = jax_view(layer.input_scale)
+             dtype_info = jnp.finfo(weight_jax.dtype)
+             maxval = float(dtype_info.max)
+             minval = float(dtype_info.min)
+             x_q = jnp.clip(x_jax / input_scale.astype(x_jax.dtype), minval,
+                            maxval).astype(weight_jax.dtype)
+
+             outs = jax.lax.dot_general(
+                 x_q,
+                 weight_jax,
+                 (((1, ), (1, )), ((), ())),
+                 preferred_element_type=jnp.float32,
+             )
+             outs *= weight_scale_jax
+             outs = outs.astype(x_jax.dtype)
+         else:
+             outs = sharded_quantized_matmul(x_jax, weight_jax,
+                                             weight_scale_jax,
+                                             self.jax_config.mesh,
+                                             self.jax_config.weight_sharding)
+
+         if bias is not None and not layer.skip_bias_add:
+             outs += jax_view(bias)
+         outs = slice_sharded_tensor_for_concatenation(
+             outs, self.jax_config.output_sizes, self.jax_config.n_shards)
+         return torch_view(jnp.concatenate(outs, axis=-1))
+
+     def _apply_split(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]) -> torch.Tensor:
+         assert isinstance(layer.weight, torch.nn.ParameterList)
+
+         x_jax = jax_view(x)
+         outs = []
+         for i, (weight, weight_scale) in enumerate(
+                 zip(layer.weight, layer.weight_scale)):
+             weight_jax = jax_view(weight)
+             weight_scale_jax = jax_view(weight_scale)
+
+             if self.is_static_input_scheme:
+                 # TODO(kyuyeunk): Add kernel support for static quant
+                 input_scale = jax_view(layer.input_scale)
+                 dtype_info = jnp.finfo(weight_jax.dtype)
+                 maxval = float(dtype_info.max)
+                 minval = float(dtype_info.min)
+                 x_q = jnp.clip(x_jax / input_scale.astype(x_jax.dtype), minval,
+                                maxval).astype(weight_jax.dtype)
+
+                 out = jax.lax.dot_general(
+                     x_q,
+                     weight_jax,
+                     (((1, ), (1, )), ((), ())),
+                     preferred_element_type=jnp.float32,
+                 )
+                 # TODO(kyuyeunk): Investigate performance gain from merging scales.
+                 # out *= weight_scale_jax
+                 out *= weight_scale_jax * input_scale
+                 out = out.astype(x_jax.dtype)
+             else:
+                 out = sharded_quantized_matmul(x_jax, weight_jax,
+                                                weight_scale_jax,
+                                                self.jax_config.mesh,
+                                                self.jax_config.weight_sharding)
+
+             if bias is not None and not layer.skip_bias_add:
+                 out += jax_view(bias[i])
+             outs.append(out)
+         return torch_view(jnp.concatenate(outs, axis=-1))
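
The static-input-scheme branch of _apply_split above is ordinary quantized-matmul arithmetic: scale activations into the quantized dtype, contract with dot_general, then dequantize with weight_scale * input_scale. The standalone sketch below (not package code) reproduces that math with int8 standing in for fp8 and round-to-nearest quantization; the shapes and scales are made up for illustration.

import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), (4, 16))   # [tokens, in_features]
w = jax.random.normal(jax.random.PRNGKey(1), (8, 16))   # [out_features, in_features]

# Per-tensor symmetric scales for activations and weights.
input_scale = jnp.max(jnp.abs(x)) / 127.0
weight_scale = jnp.max(jnp.abs(w)) / 127.0
x_q = jnp.clip(jnp.round(x / input_scale), -127, 127).astype(jnp.int8)
w_q = jnp.clip(jnp.round(w / weight_scale), -127, 127).astype(jnp.int8)

# Same dimension numbers as the kernel above: contract x's dim 1 with w's dim 1.
acc = jax.lax.dot_general(x_q, w_q, (((1,), (1,)), ((), ())),
                          preferred_element_type=jnp.int32)
y = acc.astype(jnp.float32) * weight_scale * input_scale

print(jnp.max(jnp.abs(y - x @ w.T)))   # small quantization error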