tpu-inference 0.11.1.dev202511150811__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (179)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +105 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +654 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +96 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +182 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +236 -0
  27. tpu_inference/__init__.py +34 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +51 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +728 -0
  37. tpu_inference/distributed/utils.py +59 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +107 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +362 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +390 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +507 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +105 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +311 -0
  120. tpu_inference/mock/__init__.py +0 -0
  121. tpu_inference/mock/vllm_config_utils.py +28 -0
  122. tpu_inference/mock/vllm_envs.py +1219 -0
  123. tpu_inference/mock/vllm_logger.py +212 -0
  124. tpu_inference/mock/vllm_logging_utils.py +15 -0
  125. tpu_inference/models/__init__.py +0 -0
  126. tpu_inference/models/common/__init__.py +0 -0
  127. tpu_inference/models/common/model_loader.py +444 -0
  128. tpu_inference/models/jax/__init__.py +0 -0
  129. tpu_inference/models/jax/deepseek_v3.py +868 -0
  130. tpu_inference/models/jax/gpt_oss.py +492 -0
  131. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  132. tpu_inference/models/jax/llama3.py +375 -0
  133. tpu_inference/models/jax/llama4.py +629 -0
  134. tpu_inference/models/jax/llama_eagle3.py +333 -0
  135. tpu_inference/models/jax/phi3.py +376 -0
  136. tpu_inference/models/jax/qwen2.py +375 -0
  137. tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
  138. tpu_inference/models/jax/qwen3.py +302 -0
  139. tpu_inference/models/jax/utils/__init__.py +0 -0
  140. tpu_inference/models/jax/utils/file_utils.py +96 -0
  141. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  142. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  143. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  144. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  145. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  146. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  147. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  148. tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
  149. tpu_inference/models/jax/utils/weight_utils.py +529 -0
  150. tpu_inference/models/vllm/__init__.py +0 -0
  151. tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
  152. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  153. tpu_inference/platforms/__init__.py +2 -0
  154. tpu_inference/platforms/tpu_platform.py +269 -0
  155. tpu_inference/runner/__init__.py +0 -0
  156. tpu_inference/runner/block_table.py +122 -0
  157. tpu_inference/runner/compilation_manager.py +780 -0
  158. tpu_inference/runner/input_batch.py +435 -0
  159. tpu_inference/runner/kv_cache.py +132 -0
  160. tpu_inference/runner/kv_cache_manager.py +479 -0
  161. tpu_inference/runner/lora_utils.py +92 -0
  162. tpu_inference/runner/multimodal_manager.py +217 -0
  163. tpu_inference/runner/persistent_batch_manager.py +244 -0
  164. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  165. tpu_inference/runner/structured_decoding_manager.py +88 -0
  166. tpu_inference/runner/tpu_runner.py +1620 -0
  167. tpu_inference/runner/utils.py +426 -0
  168. tpu_inference/spec_decode/__init__.py +0 -0
  169. tpu_inference/spec_decode/jax/__init__.py +0 -0
  170. tpu_inference/spec_decode/jax/eagle3.py +367 -0
  171. tpu_inference/tpu_info.py +77 -0
  172. tpu_inference/utils.py +317 -0
  173. tpu_inference/worker/__init__.py +0 -0
  174. tpu_inference/worker/tpu_worker.py +321 -0
  175. tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
  176. tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
  177. tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
  178. tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
  179. tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
@@ -0,0 +1,39 @@
+ import copy
+
+ from jax.sharding import Mesh
+ from vllm.config import VllmConfig
+ from vllm.model_executor.layers.quantization.base_config import \
+     QuantizationConfig
+
+ from tpu_inference.layers.common import quant_methods
+ from tpu_inference.layers.vllm.quantization.awq import VllmAWQConfig
+ from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
+ from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import \
+     VllmCompressedTensorsConfig  # noqa: E501
+ from tpu_inference.layers.vllm.quantization.mxfp4 import VllmMxfp4Config
+ from tpu_inference.layers.vllm.quantization.unquantized import \
+     VllmUnquantizedConfig
+
+
+ def get_tpu_quantization_config(vllm_config: VllmConfig,
+                                 mesh: Mesh) -> QuantizationConfig:
+     model_config = copy.deepcopy(vllm_config.model_config)
+     # TODO(kyuyeunk): Add support for "tpu_int8".
+     method_to_config: dict[str, str] = {
+         None: VllmUnquantizedConfig,
+         quant_methods.COMPRESSED_TENSORS: VllmCompressedTensorsConfig,
+         quant_methods.AWQ: VllmAWQConfig,
+         quant_methods.MXFP4: VllmMxfp4Config,
+     }
+     if model_config.quantization not in method_to_config:
+         raise NotImplementedError(
+             f"{model_config.quantization} quantization method not supported."
+             f" Supported methods are {method_to_config.keys()}")
+     quant_config = method_to_config[model_config.quantization]
+     assert issubclass(quant_config, JaxCommonConfig)
+     quant_config.set_configs(vllm_config, mesh)
+
+     model_config.quantization = quant_methods.get_tpu_quant_method(
+         quant_config.get_name())
+     return VllmConfig.get_quantization_config(model_config,
+                                               vllm_config.load_config)
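
For context, `get_tpu_quantization_config` above is a small dispatch helper: it maps the checkpoint's declared quantization method to the matching TPU config class, binds the vLLM config and JAX mesh to that class via `set_configs`, and then lets vLLM resolve the final quantization config. Below is a minimal usage sketch; the single-axis "model" mesh layout and the way the `VllmConfig` is obtained are assumptions for illustration, not part of the package.

# Hypothetical usage sketch (not part of the package): build a 1-D "model"
# mesh over the local devices and let the dispatcher pick the config class.
import jax
import numpy as np
from jax.sharding import Mesh
from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig

from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config


def build_tpu_quant_config(vllm_config: VllmConfig) -> QuantizationConfig:
    # Assumed mesh layout: one "model" axis spanning all visible devices.
    mesh = Mesh(np.array(jax.devices()), axis_names=("model",))
    # Dispatches on vllm_config.model_config.quantization: None or one of the
    # quant_methods.{COMPRESSED_TENSORS, AWQ, MXFP4} values; anything else
    # raises NotImplementedError (see the function body above).
    return get_tpu_quantization_config(vllm_config, mesh)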
@@ -0,0 +1,207 @@
+ from typing import Optional, Union
+
+ import jax
+ import jax.numpy as jnp
+ import torch
+ from jax.sharding import NamedSharding, PartitionSpec
+ from torchax.interop import jax_view, torch_view
+ from vllm.logger import init_logger
+ from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+ from vllm.model_executor.layers.quantization import \
+     register_quantization_config
+ from vllm.model_executor.layers.quantization.awq import (AWQConfig,
+                                                          AWQLinearMethod)
+ from vllm.model_executor.layers.quantization.base_config import \
+     QuantizeMethodBase
+ from vllm.model_executor.layers.quantization.utils.quant_utils import (
+     is_layer_skipped, unpack_quantized_values_into_int32)
+ from vllm.scalar_type import scalar_types
+
+ from tpu_inference.layers.common.quant_methods import AWQ, get_tpu_quant_method
+ from tpu_inference.layers.vllm.linear_common import (
+     slice_sharded_tensor_for_concatenation, torch_to_jax_param)
+ from tpu_inference.layers.vllm.quantization.common import (
+     JaxCommonConfig, JaxCommonLinearConfig)
+ from tpu_inference.layers.vllm.quantization.unquantized import \
+     VllmUnquantizedLinearMethod
+
+ P = PartitionSpec
+ logger = init_logger(__name__)
+
+
+ @register_quantization_config(get_tpu_quant_method(AWQ))
+ class VllmAWQConfig(AWQConfig, JaxCommonConfig):
+
+     @classmethod
+     def get_name(cls):
+         return AWQ
+
+     def get_supported_act_dtypes(self) -> list[torch.dtype]:
+         # NOTE: AWQ checkpoints are quantized with float16. But on TPUs,
+         # bfloat16 is significantly preferred over float16. This might lead
+         # to some numeric differences in the output.
+         return [torch.bfloat16]
+
+     def get_quant_method(
+         self, layer: torch.nn.Module, prefix: str
+     ) -> Optional[Union["LinearMethodBase", "QuantizeMethodBase"]]:
+         if isinstance(layer, LinearBase):
+             linear_config = self.get_linear_config(layer)
+             if is_layer_skipped(prefix, self.modules_to_not_convert):
+                 return VllmUnquantizedLinearMethod(linear_config)
+             return VllmAWQLinearMethod(self, linear_config)
+         elif isinstance(layer, FusedMoE):
+             raise NotImplementedError(
+                 "AWQ FusedMoE is currently not supported in torchax-jax")
+         return None
+
+
+ class VllmAWQLinearMethod(AWQLinearMethod):
+
+     def __init__(self, quant_config: VllmAWQConfig,
+                  jax_config: JaxCommonLinearConfig):
+         super().__init__(quant_config)
+         self.jax_config = jax_config
+
+         out_sharding, in_sharding = self.jax_config.weight_sharding[:]
+         self.jax_config.weight_sharding = P(in_sharding, None, out_sharding)
+         self.jax_config.scale_sharding = P(in_sharding, out_sharding)
+
+     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+         qweight = layer.qweight
+         qweight = unpack_awq_weight(qweight, qweight.packed_dim)
+
+         group_size = self.quant_config.group_size
+         # Reshape so that each qweight[i] is quantized with the same scales[i].
+         qweight = qweight.reshape((-1, group_size, layer.output_size))
+         qweight = torch_to_jax_param(qweight,
+                                      NamedSharding(
+                                          self.jax_config.mesh,
+                                          self.jax_config.weight_sharding),
+                                      self.jax_config.output_sizes,
+                                      self.jax_config.n_shards,
+                                      self.jax_config.fuse_matmuls,
+                                      dim=2,
+                                      jax_dtype=jnp.uint4)
+         delattr(layer, "qweight")
+         layer.qweight = qweight
+
+         qzeros = layer.qzeros
+         qzeros = unpack_awq_weight(qzeros, qzeros.packed_dim)
+         qzeros = torch_to_jax_param(qzeros,
+                                     NamedSharding(
+                                         self.jax_config.mesh,
+                                         self.jax_config.scale_sharding),
+                                     self.jax_config.output_sizes,
+                                     self.jax_config.n_shards,
+                                     self.jax_config.fuse_matmuls,
+                                     dim=1,
+                                     jax_dtype=jnp.uint4)
+         delattr(layer, "qzeros")
+         layer.qzeros = qzeros
+
+         scales = torch_to_jax_param(layer.scales,
+                                     NamedSharding(
+                                         self.jax_config.mesh,
+                                         self.jax_config.scale_sharding),
+                                     self.jax_config.output_sizes,
+                                     self.jax_config.n_shards,
+                                     self.jax_config.fuse_matmuls,
+                                     dim=1)
+         delattr(layer, "scales")
+         layer.scales = scales
+
+         if layer.bias is not None and not layer.skip_bias_add:
+             if layer.return_bias:
+                 logger.warning_once("Bias might return incorrect value.")
+
+             bias = torch_to_jax_param(
+                 layer.bias,
+                 NamedSharding(self.jax_config.mesh,
+                               self.jax_config.bias_sharding),
+                 self.jax_config.output_sizes,
+                 self.jax_config.n_shards,
+                 self.jax_config.fuse_matmuls,
+             )
+             delattr(layer, "bias")
+             layer.bias = bias
+
+     def apply(self,
+               layer: torch.nn.Module,
+               x: torch.Tensor,
+               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+         with jax.named_scope(layer._get_name()):
+             if self.jax_config.fuse_matmuls:
+                 out = self._apply_fused(layer, x, bias)
+             else:
+                 out = self._apply_split(layer, x, bias)
+
+         return out
+
+     def _apply_fused(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+         x_jax = jax_view(x)
+
+         qweight = jax_view(layer.qweight)
+         qzeros = jnp.expand_dims(jax_view(layer.qzeros), 1)
+         scales = jnp.expand_dims(jax_view(layer.scales), 1)
+
+         qweight = qweight.astype(jnp.int8)
+         qzeros = qzeros.astype(jnp.int8)
+
+         weight = (qweight - qzeros) * scales
+         weight = weight.reshape((-1, weight.shape[-1]))
+         outs = jnp.einsum("bd,df->bf", x_jax, weight)
+
+         if bias is not None and not layer.skip_bias_add:
+             outs += bias.jax()
+
+         outs = slice_sharded_tensor_for_concatenation(
+             outs, self.jax_config.output_sizes, self.jax_config.n_shards)
+         out = jnp.concatenate(outs, axis=-1)
+         return torch_view(out)
+
+     def _apply_split(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+         assert isinstance(layer.qweight, torch.nn.ParameterList)
+
+         x_jax = jax_view(x)
+         params = zip(layer.qweight, layer.qzeros, layer.scales)
+         outs = []
+         for i, (qweight, qzeros, scales) in enumerate(params):
+             qweight = jax_view(qweight)
+             scales = jnp.expand_dims(jax_view(scales), 1)
+             qzeros = jnp.expand_dims(jax_view(qzeros), 1)
+
+             qweight = qweight.astype(jnp.int8)
+             qzeros = qzeros.astype(jnp.int8)
+
+             weight = (qweight - qzeros) * scales
+             weight = weight.reshape((-1, weight.shape[-1]))
+             out = jnp.einsum("bd,df->bf", x_jax, weight)
+
+             if bias is not None and not layer.skip_bias_add:
+                 out += jax_view(bias[i])
+
+             outs.append(out)
+         out = jnp.concatenate(outs, axis=-1)
+         return torch_view(out)
+
+
+ def unpack_awq_weight(weight: torch.Tensor, packed_dim: int):
+     weight = unpack_quantized_values_into_int32(weight, scalar_types.uint4,
+                                                 packed_dim)
+
+     # AWQ packs 8 uint4 values into 32 bits in this order: (0, 2, 4, 6, 1, 3, 5, 7).
+     # The following list maps the order used by AWQ back to ascending order.
+     reverse_awq_order = (0, 4, 1, 5, 2, 6, 3, 7)
+
+     orig_shape = weight.shape
+     weight = weight.reshape(orig_shape[:-1] + (-1, 8))
+     return weight[..., reverse_awq_order].reshape(orig_shape)
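
To make the packing comment in `unpack_awq_weight` concrete, the sketch below reproduces the nibble reordering and the `(qweight - qzeros) * scales` dequantization from `_apply_fused` in plain NumPy. It is illustrative only, assumes int32-packed inputs, and does not use the package's `torch_to_jax_param` or sharding machinery.

# Illustrative NumPy-only sketch (not the package's code path): undo AWQ's
# in-word nibble ordering, then dequantize with (q - zero) * scale.
import numpy as np

# AWQ stores 8 uint4 values per 32-bit word in the order (0, 2, 4, 6, 1, 3, 5, 7);
# indexing the extracted nibbles with (0, 4, 1, 5, 2, 6, 3, 7) restores ascending order.
REVERSE_AWQ_ORDER = (0, 4, 1, 5, 2, 6, 3, 7)


def unpack_and_reorder(packed: np.ndarray) -> np.ndarray:
    """packed: int32 array whose last dimension holds 8 packed uint4 values per word."""
    nibbles = np.stack([(packed >> (4 * i)) & 0xF for i in range(8)], axis=-1)
    nibbles = nibbles[..., REVERSE_AWQ_ORDER]
    return nibbles.reshape(*packed.shape[:-1], -1)  # last dim grows by 8x


def dequantize(qweight: np.ndarray, qzeros: np.ndarray, scales: np.ndarray) -> np.ndarray:
    # Same arithmetic as _apply_fused: subtract the zero point, then scale.
    return (qweight.astype(np.int8) - qzeros.astype(np.int8)) * scales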
@@ -0,0 +1,105 @@
+ import torchax
+ from jax.sharding import Mesh, PartitionSpec
+ from vllm.config import VllmConfig
+ from vllm.logger import init_logger
+ from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEConfig
+ # yapf: disable
+ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                                LinearBase,
+                                                MergedColumnParallelLinear,
+                                                QKVParallelLinear,
+                                                ReplicatedLinear,
+                                                RowParallelLinear)
+
+ from tpu_inference.layers.vllm.linear_common import \
+     get_model_matmul_fusion_assignment
+ from tpu_inference.utils import TPU_SECOND_LAST_MINOR
+
+ # yapf: enable
+
+ P = PartitionSpec
+
+ logger = init_logger(__name__)
+
+
+ class JaxCommonLinearConfig:
+
+     def __init__(self, vllm_config: VllmConfig, mesh: Mesh, layer: LinearBase):
+         assert isinstance(layer, LinearBase)
+
+         self.mesh = mesh
+         self.output_sizes = [layer.output_size]
+         self.weight_sharding = P(None, None)
+         self.fuse_matmuls = True
+         self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism
+         self.input_sharding = None
+         self.output_sharding = None
+
+         if isinstance(layer, RowParallelLinear):
+             self.weight_sharding = P(None, "model")
+             if self.enable_sequence_parallelism:
+                 self.output_sharding = P("model", None)
+         elif isinstance(layer, ColumnParallelLinear):
+             self.weight_sharding = P("model", None)
+             if self.enable_sequence_parallelism:
+                 self.input_sharding = P("model", None)
+
+             if isinstance(layer, MergedColumnParallelLinear) or isinstance(
+                     layer, QKVParallelLinear):
+                 self.output_sizes = layer.output_sizes
+
+                 self.fuse_matmuls = get_model_matmul_fusion_assignment(
+                     vllm_config.model_config.model,
+                     vllm_config.scheduler_config.max_num_batched_tokens,
+                     vllm_config.parallel_config.tensor_parallel_size,
+                     layer._get_name())
+         elif isinstance(layer, ReplicatedLinear):
+             self.weight_sharding = P(None, None)
+         else:
+             logger.warning(
+                 "Unsupported linear layer type of %s. Can potentially yield "
+                 " bad performance.", type(layer))
+
+         self.bias_sharding = P(self.weight_sharding[0])
+         self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)
+
+     def get_input_sharding(self, x: torchax.tensor.Tensor):
+         if self.enable_sequence_parallelism:
+             token_num = x.shape[0]
+             # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
+             if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
+                 return self.input_sharding
+             else:
+                 return None
+         return self.input_sharding
+
+     def get_output_sharding(self, x: torchax.tensor.Tensor):
+         if self.enable_sequence_parallelism:
+             token_num = x.shape[0]
+             # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
+             if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
+                 return self.output_sharding
+             else:
+                 return None
+         return self.output_sharding
+
+
+ class JaxCommonConfig:
+     vllm_config: VllmConfig
+     mesh: Mesh
+
+     @classmethod
+     def set_configs(cls, vllm_config: VllmConfig, mesh: Mesh):
+         cls.vllm_config = vllm_config
+         cls.mesh = mesh
+
+     def get_linear_config(self, layer: LinearBase) -> JaxCommonLinearConfig:
+         assert isinstance(layer, LinearBase)
+         return JaxCommonLinearConfig(self.vllm_config, self.mesh, layer)
+
+     def get_moe_config(self, layer: FusedMoE) -> FusedMoEConfig:
+         assert isinstance(layer, FusedMoE)
+         moe_config = layer.moe_config
+         use_ep = self.vllm_config.parallel_config.enable_expert_parallel
+         moe_config.moe_parallel_config.use_ep = use_ep
+         return moe_config
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from jax.sharding import PartitionSpec
5
+ from vllm.attention.layer import Attention
6
+ from vllm.logger import init_logger
7
+ from vllm.model_executor.layers.fused_moe.layer import FusedMoE
8
+ from vllm.model_executor.layers.linear import LinearBase
9
+ from vllm.model_executor.layers.quantization import \
10
+ register_quantization_config
11
+ from vllm.model_executor.layers.quantization.base_config import \
12
+ QuantizeMethodBase # noqa: E501
13
+ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
14
+ CompressedTensorsConfig, CompressedTensorsKVCacheMethod,
15
+ CompressedTensorsLinearMethod, CompressedTensorsScheme)
16
+ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
17
+ find_matched_target, should_ignore_layer)
18
+
19
+ from tpu_inference.layers.common.quant_methods import (COMPRESSED_TENSORS,
20
+ get_tpu_quant_method)
21
+ from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
22
+ from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
23
+ VllmCompressedTensorsW8A8Fp8MoEMethod
24
+ from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
25
+ VllmCompressedTensorsW8A8Fp8
26
+ from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
27
+ VllmCompressedTensorsW8A8Int8
28
+ from tpu_inference.layers.vllm.quantization.unquantized import \
29
+ VllmUnquantizedConfig
30
+
31
+ P = PartitionSpec
32
+ logger = init_logger(__name__)
33
+
34
+
35
+ @register_quantization_config(get_tpu_quant_method(COMPRESSED_TENSORS))
36
+ class VllmCompressedTensorsConfig(CompressedTensorsConfig, JaxCommonConfig):
37
+
38
+ @classmethod
39
+ def get_name(cls) -> str:
40
+ return COMPRESSED_TENSORS
41
+
42
+ def get_scheme(self,
43
+ layer: torch.nn.Module,
44
+ layer_name: Optional[str] = None
45
+ ) -> Optional["CompressedTensorsScheme"]:
46
+ """
47
+ compressed-tensors supports non uniform in the following way:
48
+
49
+ targets of config_groups: There can be N config_groups which each
50
+ have a quantization scheme. Each config_group has a list of targets
51
+ which can be a full layer_name, a regex for a layer_name, or
52
+ an nn.Module name.
53
+
54
+ Detect whether a layer_name is found in any target and
55
+ use the quantization scheme corresponding to the matched target
56
+ to select the CompressedTensorsScheme used for inference.
57
+ """
58
+
59
+ # Will be empty for models with only sparsity
60
+ weight_quant = input_quant = None
61
+ if self.target_scheme_map:
62
+ matched_target = find_matched_target(
63
+ layer_name=layer_name,
64
+ module=layer,
65
+ targets=self.target_scheme_map.keys(),
66
+ fused_mapping=self.packed_modules_mapping,
67
+ )
68
+
69
+ scheme_dict = self.target_scheme_map[matched_target]
70
+ weight_quant = scheme_dict.get("weights")
71
+ input_quant = scheme_dict.get("input_activations")
72
+
73
+ if weight_quant is None:
74
+ logger.warning_once("Acceleration for non-quantized schemes is "
75
+ "not supported by Compressed Tensors. "
76
+ "Falling back to UnquantizedLinearMethod")
77
+ return None
78
+
79
+ # TODO(kyuyeunk): Add support for different act_quant_format
80
+
81
+ linear_config = self.get_linear_config(layer)
82
+ if self._is_fp8_w8a8(weight_quant, input_quant):
83
+ is_static_input_scheme = input_quant and not input_quant.dynamic
84
+ return VllmCompressedTensorsW8A8Fp8(
85
+ weight_quant=weight_quant,
86
+ is_static_input_scheme=is_static_input_scheme,
87
+ jax_config=linear_config,
88
+ )
89
+ if self._is_dynamic_token_w8a8(weight_quant, input_quant):
90
+ return VllmCompressedTensorsW8A8Int8(
91
+ strategy=weight_quant.strategy,
92
+ is_static_input_scheme=False,
93
+ input_symmetric=input_quant.symmetric,
94
+ jax_config=linear_config,
95
+ )
96
+ raise NotImplementedError(
97
+ "No compressed-tensors compatible scheme was found.")
98
+
99
+ def get_quant_method(
100
+ self,
101
+ layer: torch.nn.Module,
102
+ prefix: str,
103
+ ) -> Optional[QuantizeMethodBase]:
104
+ if should_ignore_layer(prefix,
105
+ ignore=self.ignore,
106
+ fused_mapping=self.packed_modules_mapping):
107
+ return VllmUnquantizedConfig.get_quant_method(self, layer, prefix)
108
+ if isinstance(layer, LinearBase):
109
+ scheme = self.get_scheme(layer=layer, layer_name=prefix)
110
+ if scheme is None:
111
+ return VllmUnquantizedConfig.get_quant_method(
112
+ self, layer, prefix)
113
+ layer.scheme = scheme
114
+ return CompressedTensorsLinearMethod(self)
115
+ if isinstance(layer, FusedMoE):
116
+ return VllmCompressedTensorsW8A8Fp8MoEMethod(
117
+ self, layer.quant_config, self.mesh)
118
+ if isinstance(layer, Attention):
119
+ return CompressedTensorsKVCacheMethod(self)
120
+ return None
@@ -0,0 +1,203 @@
1
+ from typing import Callable, Optional, Union
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from jax.experimental.layout import Format, Layout
8
+ from jax.sharding import Mesh, NamedSharding
9
+ from jax.sharding import PartitionSpec as P
10
+ from torch.nn.parameter import Parameter
11
+ from torchax.interop import call_jax, torch_view
12
+ from torchax.ops.mappings import t2j
13
+ from vllm.logger import init_logger
14
+ from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEConfig
15
+ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import \
16
+ CompressedTensorsConfig
17
+ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import \
18
+ CompressedTensorsW8A8Fp8MoEMethod
19
+ from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
20
+ WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
21
+
22
+ from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
23
+
24
+ logger = init_logger(__name__)
25
+
26
+
27
+ class VllmCompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsW8A8Fp8MoEMethod,
28
+ JaxCommonConfig):
29
+
30
+ def __init__(self, quant_config: "CompressedTensorsConfig",
31
+ moe: FusedMoEConfig, mesh: Mesh):
32
+ super().__init__(quant_config, moe)
33
+ self.mesh = mesh
34
+ self.quant_config = quant_config
35
+
36
+ # disable GPU paths
37
+ self.use_marlin = False
38
+ self.rocm_aiter_moe_enabled = False # is_rocm_aiter_moe_enabled()
39
+ self.is_fp8_w8a8_sm100 = False
40
+ self.use_cutlass = False
41
+ self.disable_expert_map = False
42
+
43
+ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
44
+ assert isinstance(layer, FusedMoE)
45
+
46
+ intermediate_size = layer.w13_weight.shape[1] // 2
47
+ w1_weight = layer.w13_weight[:, :intermediate_size]
48
+ w3_weight = layer.w13_weight[:, intermediate_size:]
49
+ w1_weight_scale = layer.w13_weight_scale[:, :intermediate_size]
50
+ w3_weight_scale = layer.w13_weight_scale[:, intermediate_size:]
51
+
52
+ w2_weight = t2j(layer.w2_weight, use_dlpack=False)
53
+ w2_weight_scale = t2j(layer.w2_weight_scale.to(torch.bfloat16),
54
+ use_dlpack=False)
55
+ w1_weight = t2j(w1_weight, use_dlpack=False)
56
+ w1_weight_scale = t2j(w1_weight_scale.to(torch.bfloat16),
57
+ use_dlpack=False)
58
+ w3_weight = t2j(w3_weight, use_dlpack=False)
59
+ w3_weight_scale = t2j(w3_weight_scale.to(torch.bfloat16),
60
+ use_dlpack=False)
61
+
62
+ if layer.use_ep:
63
+ format = Format(Layout((0, 1, 2)),
64
+ NamedSharding(self.mesh, P("model", None, None)))
65
+ w1_weight = jax.device_put(w1_weight, format)
66
+ w1_weight_scale = jax.device_put(w1_weight_scale, format)
67
+ w3_weight = jax.device_put(w3_weight, format)
68
+ w3_weight_scale = jax.device_put(w3_weight_scale, format)
69
+ w2_weight = jax.device_put(w2_weight, format)
70
+ w2_weight_scale = jax.device_put(w2_weight_scale, format)
71
+ else:
72
+ assert intermediate_size == w2_weight.shape[-1]
73
+ n_shards = self.mesh.shape["model"]
74
+ assert intermediate_size % n_shards == 0
75
+
76
+ # TODO: enable this if using fused weights
77
+ # output_sizes = [intermediate_size, intermediate_size]
78
+ # w13_weight = reorder_concatenated_tensor_for_sharding(
79
+ # w13_weight, output_sizes, n_shards, dim=1
80
+ # )
81
+
82
+ w13_format = Format(
83
+ Layout((0, 1, 2)),
84
+ NamedSharding(self.mesh, P(None, "model", None)))
85
+ w1_weight = jax.device_put(w1_weight, w13_format)
86
+ w1_weight_scale = jax.device_put(w1_weight_scale, w13_format)
87
+ w3_weight = jax.device_put(w3_weight, w13_format)
88
+ w3_weight_scale = jax.device_put(w3_weight_scale, w13_format)
89
+ w2_weight = jax.device_put(
90
+ w2_weight,
91
+ Format(Layout((0, 1, 2)),
92
+ NamedSharding(self.mesh, P(None, None, "model"))),
93
+ )
94
+ w2_weight_scale = jax.device_put(
95
+ w2_weight_scale,
96
+ Format(Layout((0, 1, 2)), NamedSharding(self.mesh, P())),
97
+ ) # replicate
98
+
99
+ w1_weight = Parameter(torch_view(w1_weight), requires_grad=False)
100
+ w1_weight_scale = Parameter(torch_view(w1_weight_scale),
101
+ requires_grad=False)
102
+ w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
103
+ w2_weight_scale = Parameter(torch_view(w2_weight_scale),
104
+ requires_grad=False)
105
+ w3_weight = Parameter(torch_view(w3_weight), requires_grad=False)
106
+ w3_weight_scale = Parameter(torch_view(w3_weight_scale),
107
+ requires_grad=False)
108
+
109
+ # TODO dont reuse variable
110
+ layer.w13_weight = w1_weight
111
+ layer.w13_weight_scale = w1_weight_scale
112
+ layer.w2_weight = w2_weight
113
+ layer.w2_weight_scale = w2_weight_scale
114
+ layer.w3_weight = w3_weight
115
+ layer.w3_weight_scale = w3_weight_scale
116
+
117
+ def apply(
118
+ self,
119
+ layer: torch.nn.Module,
120
+ x: torch.Tensor,
121
+ router_logits: torch.Tensor,
122
+ top_k: int,
123
+ renormalize: bool,
124
+ use_grouped_topk: bool = False,
125
+ topk_group: Optional[int] = None,
126
+ num_expert_group: Optional[int] = None,
127
+ global_num_experts: int = -1,
128
+ expert_map: Optional[torch.Tensor] = None,
129
+ custom_routing_function: Optional[Callable] = None,
130
+ scoring_func: str = "softmax",
131
+ routed_scaling_factor: float = 1.0,
132
+ e_score_correction_bias: Optional[torch.Tensor] = None,
133
+ apply_router_weight_on_input: bool = False,
134
+ activation: str = "silu",
135
+ enable_eplb: bool = False,
136
+ expert_load_view: Optional[torch.Tensor] = None,
137
+ logical_to_physical_map: Optional[torch.Tensor] = None,
138
+ logical_replica_count: Optional[torch.Tensor] = None,
139
+ ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
140
+ assert isinstance(layer, FusedMoE)
141
+ if activation != "silu":
142
+ raise NotImplementedError(
143
+ "Only silu is supported for activation function.")
144
+ if scoring_func != "softmax":
145
+ raise NotImplementedError(
146
+ "Only softmax is supported for scoring_func")
147
+
148
+ # import sys
149
+ # sys.stdin = open(0)
150
+ # breakpoint()
151
+
152
+ # TODO: Use MoE kernel when it supports fp8
153
+
154
+ seqlen = x.shape[0]
155
+
156
+ expert_weights = F.softmax(router_logits, dim=-1)
157
+ expert_weights, expert_indices = torch.topk(expert_weights,
158
+ top_k,
159
+ dim=-1)
160
+ if renormalize:
161
+ expert_weights /= expert_weights.sum(dim=-1, keepdim=True)
162
+
163
+ # cond ffn
164
+ # e = total num of exp = 160
165
+ # t = seqlen
166
+ # o = config.imtermediate size
167
+ # i = config.dim
168
+ #torch.einsum("ti, eoi -> teo", x, layer.w13_weight) * self.w13_weight_scale)
169
+ ux1 = call_jax(jax.lax.dot,
170
+ x,
171
+ layer.w13_weight,
172
+ dimension_numbers=(((1, ), (2, )), ((), ())),
173
+ preferred_element_type=jnp.bfloat16.dtype)
174
+ x1 = F.silu(ux1 * layer.w13_weight_scale.squeeze(2))
175
+
176
+ #x3 = torch.einsum("ti, eoi -> teo", x, layer.w3_weight) * self.w3_weight_scale
177
+ x3 = call_jax(jax.lax.dot,
178
+ x,
179
+ layer.w3_weight,
180
+ dimension_numbers=(((1, ), (2, )), ((), ())),
181
+ preferred_element_type=jnp.bfloat16.dtype
182
+ ) * layer.w3_weight_scale.squeeze(2)
183
+
184
+ #expert_outs = torch.einsum("teo, eio -> tei", (x1 * x3), self.w2_weight) * self.w2_weight_scale
185
+ expert_outs = call_jax(
186
+ jax.lax.dot,
187
+ x1 * x3,
188
+ layer.w2_weight,
189
+ dimension_numbers=(((2, ), (2, )), ((1, ), (0, ))),
190
+ preferred_element_type=jnp.bfloat16.dtype).transpose(
191
+ 0, 1) * layer.w2_weight_scale.squeeze(2)
192
+
193
+ seq_indexes = torch.arange(seqlen, device='jax').unsqueeze(1)
194
+ expert_outs = expert_outs[seq_indexes, expert_indices]
195
+
196
+ # out = torch.einsum("tai,ta -> ti", expert_outs, expert_weights)
197
+ out = call_jax(jax.lax.dot,
198
+ expert_outs,
199
+ expert_weights,
200
+ dimension_numbers=(((1, ), (1, )), ((0, ), (0, ))),
201
+ preferred_element_type=jnp.bfloat16.dtype)
202
+
203
+ return out
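
The commented `torch.einsum` lines above document what each `call_jax(jax.lax.dot, ...)` call computes; the `dimension_numbers` argument follows the `jax.lax.dot_general` convention of ((lhs contracting dims, rhs contracting dims), (lhs batch dims, rhs batch dims)). The standalone sketch below replays the three contractions with `jax.lax.dot_general` on dummy tensors purely to make the shape bookkeeping explicit; the dimension sizes are arbitrary assumptions.

# Shape-checking sketch of the three MoE contractions, replayed with
# jax.lax.dot_general on dummy tensors. Sizes t, e, o, i, k are arbitrary.
import jax
import jax.numpy as jnp

t, e, o, i, k = 4, 8, 16, 32, 2
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (t, i))       # tokens x hidden
w13 = jax.random.normal(key, (e, o, i))  # experts x intermediate x hidden
w2 = jax.random.normal(key, (e, i, o))   # experts x hidden x intermediate

# "ti,eoi->teo": contract i (lhs dim 1 with rhs dim 2), no batch dims.
x1 = jax.lax.dot_general(x, w13, (((1,), (2,)), ((), ())))
assert x1.shape == (t, e, o)

# "teo,eio->tei": contract o, batch over e, then move e behind t.
expert_outs = jax.lax.dot_general(x1, w2, (((2,), (2,)), ((1,), (0,))))
assert expert_outs.shape == (e, t, i)
expert_outs = jnp.transpose(expert_outs, (1, 0, 2))  # -> (t, e, i)

# "tki,tk->ti": after the top-k gather, batch over t and contract k against
# the routing weights (the slice below just stands in for the gather).
gathered = expert_outs[:, :k, :]
weights = jax.random.normal(key, (t, k))
out = jax.lax.dot_general(gathered, weights, (((1,), (1,)), ((0,), (0,))))
assert out.shape == (t, i)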