tpu-inference 0.11.1.dev202511150811__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of tpu-inference might be problematic.

Files changed (179)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +105 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +654 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +96 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +182 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +236 -0
  27. tpu_inference/__init__.py +34 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +51 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +728 -0
  37. tpu_inference/distributed/utils.py +59 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +107 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +362 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +390 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +507 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +105 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +311 -0
  120. tpu_inference/mock/__init__.py +0 -0
  121. tpu_inference/mock/vllm_config_utils.py +28 -0
  122. tpu_inference/mock/vllm_envs.py +1219 -0
  123. tpu_inference/mock/vllm_logger.py +212 -0
  124. tpu_inference/mock/vllm_logging_utils.py +15 -0
  125. tpu_inference/models/__init__.py +0 -0
  126. tpu_inference/models/common/__init__.py +0 -0
  127. tpu_inference/models/common/model_loader.py +444 -0
  128. tpu_inference/models/jax/__init__.py +0 -0
  129. tpu_inference/models/jax/deepseek_v3.py +868 -0
  130. tpu_inference/models/jax/gpt_oss.py +492 -0
  131. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  132. tpu_inference/models/jax/llama3.py +375 -0
  133. tpu_inference/models/jax/llama4.py +629 -0
  134. tpu_inference/models/jax/llama_eagle3.py +333 -0
  135. tpu_inference/models/jax/phi3.py +376 -0
  136. tpu_inference/models/jax/qwen2.py +375 -0
  137. tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
  138. tpu_inference/models/jax/qwen3.py +302 -0
  139. tpu_inference/models/jax/utils/__init__.py +0 -0
  140. tpu_inference/models/jax/utils/file_utils.py +96 -0
  141. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  142. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  143. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  144. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  145. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  146. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  147. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  148. tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
  149. tpu_inference/models/jax/utils/weight_utils.py +529 -0
  150. tpu_inference/models/vllm/__init__.py +0 -0
  151. tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
  152. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  153. tpu_inference/platforms/__init__.py +2 -0
  154. tpu_inference/platforms/tpu_platform.py +269 -0
  155. tpu_inference/runner/__init__.py +0 -0
  156. tpu_inference/runner/block_table.py +122 -0
  157. tpu_inference/runner/compilation_manager.py +780 -0
  158. tpu_inference/runner/input_batch.py +435 -0
  159. tpu_inference/runner/kv_cache.py +132 -0
  160. tpu_inference/runner/kv_cache_manager.py +479 -0
  161. tpu_inference/runner/lora_utils.py +92 -0
  162. tpu_inference/runner/multimodal_manager.py +217 -0
  163. tpu_inference/runner/persistent_batch_manager.py +244 -0
  164. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  165. tpu_inference/runner/structured_decoding_manager.py +88 -0
  166. tpu_inference/runner/tpu_runner.py +1620 -0
  167. tpu_inference/runner/utils.py +426 -0
  168. tpu_inference/spec_decode/__init__.py +0 -0
  169. tpu_inference/spec_decode/jax/__init__.py +0 -0
  170. tpu_inference/spec_decode/jax/eagle3.py +367 -0
  171. tpu_inference/tpu_info.py +77 -0
  172. tpu_inference/utils.py +317 -0
  173. tpu_inference/worker/__init__.py +0 -0
  174. tpu_inference/worker/tpu_worker.py +321 -0
  175. tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
  176. tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
  177. tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
  178. tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
  179. tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@@ -0,0 +1,208 @@
+from typing import Optional
+
+import jax
+import jax.numpy as jnp
+import torch
+from compressed_tensors.quantization import (QuantizationArgs,
+                                              QuantizationStrategy)
+from jax.sharding import NamedSharding, PartitionSpec
+from torchax.interop import jax_view, torch_view
+from torchax.ops.mappings import t2j
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
+    CompressedTensorsW8A8Fp8
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import \
+    per_tensor_dequantize
+
+from tpu_inference.layers.vllm.linear_common import (
+    sharded_quantized_matmul, slice_sharded_tensor_for_concatenation,
+    torch_to_jax_param)
+from tpu_inference.layers.vllm.quantization.common import JaxCommonLinearConfig
+
+P = PartitionSpec
+
+
+def requantize_with_max_scale(
+        weight: torch.Tensor, weight_scale: torch.Tensor,
+        logical_widths: list[int]) -> tuple[torch.Tensor, torch.Tensor]:
+    dtype = weight.dtype
+    dtype_info = torch.finfo(dtype)
+    maxval = float(dtype_info.max)
+    minval = float(dtype_info.min)
+
+    max_w_scale = weight_scale.max()
+
+    unfused_module_in_checkpoint = (weight_scale[-1]
+                                    > torch.finfo(torch.float8_e4m3fn).min)
+
+    # If the checkpoint is unfused, requantize with the single max scale.
+    if unfused_module_in_checkpoint:
+        start = 0
+        for idx, logical_width in enumerate(logical_widths):
+            # Skip any component with zero width.
+            if logical_width == 0:
+                continue
+            end = start + logical_width
+            weight_dq = per_tensor_dequantize(weight[start:end, :],
+                                              weight_scale[idx])
+            weight_q = weight_dq / max_w_scale
+            weight[start:end, :] = weight_q.clamp(minval, maxval).to(dtype)
+            start = end
+
+    return max_w_scale, weight
+
+
+class VllmCompressedTensorsW8A8Fp8(CompressedTensorsW8A8Fp8):
+
+    def __init__(
+        self,
+        weight_quant: QuantizationArgs,
+        is_static_input_scheme: bool,
+        jax_config: JaxCommonLinearConfig,
+    ):
+        super().__init__(weight_quant, is_static_input_scheme)
+
+        self.jax_config = jax_config
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        weight = layer.weight
+        weight_scale = layer.weight_scale
+
+        if self.is_static_input_scheme:
+            # In static quant, all input_scales share the same value.
+            assert layer.input_scale.min() == layer.input_scale.max()
+            input_scale_first = layer.input_scale[0]
+
+            input_scale = jax.device_put(
+                t2j(input_scale_first, use_dlpack=False),
+                NamedSharding(self.jax_config.mesh, P()))
+            input_scale = torch.nn.Parameter(torch_view(input_scale),
+                                             requires_grad=False)
+            delattr(layer, "input_scale")
+            layer.input_scale = input_scale
+
+            # TODO(kyuyeunk): Investigate performance gain from merging scales.
+            # By merging input and weight scales, we reduce the number of muls
+            # required for dequantization from 2 (one per scale) to 1.
+            # weight_scale *= input_scale_first
+
+        if self.strategy == QuantizationStrategy.TENSOR:
+            weight_scale, weight = requantize_with_max_scale(
+                weight, weight_scale, self.jax_config.output_sizes)
+            weight_scale = jax.device_put(
+                t2j(weight_scale, use_dlpack=False),
+                NamedSharding(self.jax_config.mesh, P()))
+            weight_scale = torch.nn.Parameter(torch_view(weight_scale),
+                                              requires_grad=False)
+        else:
+            weight_scale = weight_scale.squeeze(-1)
+            weight_scale = torch_to_jax_param(
+                weight_scale,
+                NamedSharding(self.jax_config.mesh,
+                              self.jax_config.bias_sharding),
+                self.jax_config.output_sizes, self.jax_config.n_shards,
+                self.jax_config.fuse_matmuls)
+        delattr(layer, "weight_scale")
+        layer.weight_scale = weight_scale
+
+        weight = torch_to_jax_param(
+            layer.weight,
+            NamedSharding(self.jax_config.mesh,
+                          self.jax_config.weight_sharding),
+            self.jax_config.output_sizes, self.jax_config.n_shards,
+            self.jax_config.fuse_matmuls)
+        delattr(layer, "weight")
+        layer.weight = weight
+
+        if layer.bias is not None:
+            bias = torch_to_jax_param(
+                layer.bias,
+                NamedSharding(self.jax_config.mesh,
+                              self.jax_config.bias_sharding),
+                self.jax_config.output_sizes, self.jax_config.n_shards,
+                self.jax_config.fuse_matmuls)
+            delattr(layer, "bias")
+            layer.bias = bias
+
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]) -> torch.Tensor:
+        with jax.named_scope(layer._get_name()):
+            if self.jax_config.fuse_matmuls:
+                return self._apply_fused(layer, x, bias)
+            else:
+                return self._apply_split(layer, x, bias)
+
+    def _apply_fused(self, layer: torch.nn.Module, x: torch.Tensor,
+                     bias: Optional[torch.Tensor]) -> torch.Tensor:
+        x_jax = jax_view(x)
+        weight_jax = jax_view(layer.weight)
+        weight_scale_jax = jax_view(layer.weight_scale)
+
+        if self.is_static_input_scheme:
+            # TODO(kyuyeunk): Add kernel support for static quant
+            input_scale = jax_view(layer.input_scale)
+            dtype_info = jnp.finfo(weight_jax.dtype)
+            maxval = float(dtype_info.max)
+            minval = float(dtype_info.min)
+            x_q = jnp.clip(x_jax / input_scale.astype(x_jax.dtype), minval,
+                           maxval).astype(weight_jax.dtype)
+
+            outs = jax.lax.dot_general(
+                x_q,
+                weight_jax,
+                (((1, ), (1, )), ((), ())),
+                preferred_element_type=jnp.float32,
+            )
+            outs *= weight_scale_jax
+            outs = outs.astype(x_jax.dtype)
+        else:
+            outs = sharded_quantized_matmul(x_jax, weight_jax,
+                                            weight_scale_jax,
+                                            self.jax_config.mesh,
+                                            self.jax_config.weight_sharding)
+
+        if bias is not None and not layer.skip_bias_add:
+            outs += jax_view(bias)
+        outs = slice_sharded_tensor_for_concatenation(
+            outs, self.jax_config.output_sizes, self.jax_config.n_shards)
+        return torch_view(jnp.concatenate(outs, axis=-1))
+
+    def _apply_split(self, layer: torch.nn.Module, x: torch.Tensor,
+                     bias: Optional[torch.Tensor]) -> torch.Tensor:
+        assert isinstance(layer.weight, torch.nn.ParameterList)
+
+        x_jax = jax_view(x)
+        outs = []
+        for i, (weight, weight_scale) in enumerate(
+                zip(layer.weight, layer.weight_scale)):
+            weight_jax = jax_view(weight)
+            weight_scale_jax = jax_view(weight_scale)
+
+            if self.is_static_input_scheme:
+                # TODO(kyuyeunk): Add kernel support for static quant
+                input_scale = jax_view(layer.input_scale)
+                dtype_info = jnp.finfo(weight_jax.dtype)
+                maxval = float(dtype_info.max)
+                minval = float(dtype_info.min)
+                x_q = jnp.clip(x_jax / input_scale.astype(x_jax.dtype),
+                               minval, maxval).astype(weight_jax.dtype)
+
+                out = jax.lax.dot_general(
+                    x_q,
+                    weight_jax,
+                    (((1, ), (1, )), ((), ())),
+                    preferred_element_type=jnp.float32,
+                )
+                # TODO(kyuyeunk): Investigate performance gain from merging
+                # scales.
+                # out *= weight_scale_jax
+                out *= weight_scale_jax * input_scale
+                out = out.astype(x_jax.dtype)
+            else:
+                out = sharded_quantized_matmul(x_jax, weight_jax,
+                                               weight_scale_jax,
+                                               self.jax_config.mesh,
+                                               self.jax_config.weight_sharding)
+
+            if bias is not None and not layer.skip_bias_add:
+                out += jax_view(bias[i])
+            outs.append(out)
+        return torch_view(jnp.concatenate(outs, axis=-1))
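
The requantize_with_max_scale helper above collapses the per-shard scales of a fused fp8 checkpoint into one shared scale: each shard is dequantized with its own scale, re-expressed relative to the maximum scale, and clamped to the fp8 range, so a single scalar can afterwards dequantize the whole fused weight. A minimal numeric sketch of that idea, using toy values in plain jax.numpy rather than code from the package:

import jax.numpy as jnp

# Two logical shards of a fused layer (e.g. gate_proj and up_proj), each
# stored with its own per-tensor scale. Values here are made up.
shard_scales = jnp.array([0.02, 0.05])
shards = [jnp.array([10.0, -30.0]), jnp.array([40.0, 5.0])]  # quantized values

max_scale = shard_scales.max()
fp8_max = float(jnp.finfo(jnp.float8_e4m3fn).max)  # ~448 for e4m3

requantized = []
for q, s in zip(shards, shard_scales):
    w_dq = q * s  # dequantize with the shard's own scale
    # Re-express relative to the shared max scale, clamped to the fp8 range.
    requantized.append(jnp.clip(w_dq / max_scale, -fp8_max, fp8_max))

# A single shared scale now reproduces every shard's dequantized values.
for q, s, w_rq in zip(shards, shard_scales, requantized):
    assert jnp.allclose(w_rq * max_scale, q * s)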
tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
@@ -0,0 +1,136 @@
+from typing import Optional
+
+import jax
+import jax.numpy as jnp
+import torch
+from compressed_tensors.quantization import QuantizationStrategy
+from jax.sharding import NamedSharding, PartitionSpec
+from torchax.interop import jax_view, torch_view
+from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
+    CompressedTensorsW8A8Int8
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import \
+    convert_to_channelwise
+
+from tpu_inference.layers.vllm.linear_common import (
+    sharded_quantized_matmul, slice_sharded_tensor_for_concatenation,
+    torch_to_jax_param)
+from tpu_inference.layers.vllm.quantization.common import JaxCommonLinearConfig
+
+P = PartitionSpec
+logger = init_logger(__name__)
+
+
+class VllmCompressedTensorsW8A8Int8(CompressedTensorsW8A8Int8):
+
+    def __init__(self, strategy: str, is_static_input_scheme: bool,
+                 input_symmetric: bool, jax_config: JaxCommonLinearConfig):
+        super().__init__(strategy, is_static_input_scheme, input_symmetric)
+
+        self.jax_config = jax_config
+        self.is_channelwise = (self.strategy == QuantizationStrategy.CHANNEL)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        weight = torch_to_jax_param(
+            layer.weight,
+            NamedSharding(self.jax_config.mesh,
+                          self.jax_config.weight_sharding),
+            self.jax_config.output_sizes,
+            self.jax_config.n_shards,
+            self.jax_config.fuse_matmuls,
+        )
+        delattr(layer, "weight")
+        layer.weight = weight
+
+        weight_scale = layer.weight_scale
+        is_fused_module = len(layer.logical_widths) > 1
+        if is_fused_module and not self.is_channelwise:
+            weight_scale = convert_to_channelwise(weight_scale,
+                                                  layer.logical_widths)
+        weight_scale = weight_scale.squeeze(-1)
+
+        weight_scale = torch_to_jax_param(
+            weight_scale,
+            NamedSharding(self.jax_config.mesh, self.jax_config.bias_sharding),
+            self.jax_config.output_sizes,
+            self.jax_config.n_shards,
+            self.jax_config.fuse_matmuls,
+        )
+        delattr(layer, "weight_scale")
+        layer.weight_scale = weight_scale
+
+        if layer.bias is not None and not layer.skip_bias_add:
+            if layer.return_bias:
+                logger.warning_once("Bias might return incorrect value.")
+
+            bias = torch_to_jax_param(
+                layer.bias,
+                NamedSharding(self.jax_config.mesh,
+                              self.jax_config.bias_sharding),
+                self.jax_config.output_sizes,
+                self.jax_config.n_shards,
+                self.jax_config.fuse_matmuls,
+            )
+            delattr(layer, "bias")
+            layer.bias = bias
+
+        # TODO(kyuyeunk): Support static range input quantization.
+        assert getattr(layer, "input_scale", None) is None
+        assert getattr(layer, "input_zero_point", None) is None
+        assert getattr(layer, "azp_adj", None) is None
+
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]) -> torch.Tensor:
+        with jax.named_scope(layer._get_name()):
+            if self.jax_config.fuse_matmuls:
+                out = self._apply_fused(layer, x, bias)
+            else:
+                out = self._apply_split(layer, x, bias)
+
+            return out
+
+    def _apply_fused(self, layer: torch.nn.Module, x: torch.Tensor,
+                     bias: Optional[torch.Tensor]) -> torch.Tensor:
+        x_jax = jax_view(x)
+        weight_jax = jax_view(layer.weight)
+        weight_scale_jax = jax_view(layer.weight_scale)
+
+        outs = sharded_quantized_matmul(
+            x_jax,
+            weight_jax,
+            weight_scale_jax,
+            self.jax_config.mesh,
+            self.jax_config.weight_sharding,
+        )
+        if bias is not None and not layer.skip_bias_add:
+            outs += jax_view(bias)
+
+        outs = slice_sharded_tensor_for_concatenation(
+            outs, self.jax_config.output_sizes, self.jax_config.n_shards)
+        out = jnp.concatenate(outs, axis=-1)
+        return torch_view(out)
+
+    def _apply_split(self, layer: torch.nn.Module, x: torch.Tensor,
+                     bias: Optional[torch.Tensor]) -> torch.Tensor:
+        assert isinstance(layer.weight, torch.nn.ParameterList)
+
+        x_jax = jax_view(x)
+        outs = []
+        for i, (weight, weight_scale) in enumerate(
+                zip(layer.weight, layer.weight_scale)):
+            weight_jax = jax_view(weight)
+            weight_scale_jax = jax_view(weight_scale)
+
+            out = sharded_quantized_matmul(
+                x_jax,
+                weight_jax,
+                weight_scale_jax,
+                self.jax_config.mesh,
+                self.jax_config.weight_sharding,
+            )
+            if bias is not None and not layer.skip_bias_add:
+                out += jax_view(bias[i])
+
+            outs.append(out)
+        out = jnp.concatenate(outs, axis=-1)
+        return torch_view(out)
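
Both int8 apply paths above reduce to the same math: multiply the activations by the transposed int8 weight and rescale each output channel by its weight scale; sharded_quantized_matmul handles the sharded (and possibly kernel-backed) version of this. A rough unsharded reference of that computation, assuming the [out_features, in_features] weight layout used above; this is a sketch only, not the package's kernel:

import jax
import jax.numpy as jnp

def quantized_matmul_reference(x, w_q, w_scale):
    # x:       [tokens, in_features] activations (e.g. bf16)
    # w_q:     [out_features, in_features] int8 weights
    # w_scale: [out_features] per-output-channel scales
    acc = jax.lax.dot_general(
        x, w_q.astype(x.dtype),
        (((1, ), (1, )), ((), ())),  # contract the in_features dimensions
        preferred_element_type=jnp.float32,
    )
    return (acc * w_scale).astype(x.dtype)

x = jnp.ones((2, 8), dtype=jnp.bfloat16)
w_q = jnp.ones((4, 8), dtype=jnp.int8)
w_scale = jnp.full((4, ), 0.1, dtype=jnp.float32)
print(quantized_matmul_reference(x, w_q, w_scale))  # roughly 0.8 everywhere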
tpu_inference/layers/vllm/quantization/mxfp4.py
@@ -0,0 +1,266 @@
+from typing import Callable, Optional, Union
+
+import jax
+import jax.numpy as jnp
+import torch
+from jax.experimental.layout import Format, Layout
+from jax.sharding import Mesh, NamedSharding, PartitionSpec
+from torch.nn.parameter import Parameter
+from torchax.interop import jax_view, torch_view
+from torchax.ops.mappings import t2j
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.config import (
+    FusedMoEConfig, FusedMoEQuantConfig, biased_moe_quant_config)
+from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
+                                                        FusedMoEMethodBase)
+from vllm.model_executor.layers.linear import LinearBase
+from vllm.model_executor.layers.quantization import \
+    register_quantization_config
+from vllm.model_executor.layers.quantization.base_config import \
+    QuantizeMethodBase
+from vllm.model_executor.layers.quantization.mxfp4 import (Mxfp4Backend,
+                                                           Mxfp4Config,
+                                                           Mxfp4MoEMethod)
+from vllm.model_executor.layers.quantization.utils.quant_utils import \
+    is_layer_skipped
+
+from tpu_inference.layers.common.quant_methods import (MXFP4,
+                                                       get_tpu_quant_method)
+from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded
+from tpu_inference.layers.vllm.linear_common import \
+    reorder_concatenated_tensor_for_sharding
+from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
+from tpu_inference.layers.vllm.quantization.unquantized import \
+    VllmUnquantizedLinearMethod
+
+MXFP4_BLOCK_SIZE = 32
+
+P = PartitionSpec
+logger = init_logger(__name__)
+
+
+# TODO(kyuyeunk): Move these functions into a common utility file.
+def u8_unpack_e2m1(u8_packed_e2m1: jax.Array) -> jax.Array:
+    assert u8_packed_e2m1.dtype == jnp.uint8
+    e2m1 = jax.lax.bitcast_convert_type(u8_packed_e2m1, jnp.float4_e2m1fn)
+    # bitcast adds a trailing dimension that splits each 8-bit value into two
+    # e2m1 values; flatten it into the last dim.
+    return jnp.reshape(e2m1, e2m1.shape[:-2] + (-1, ))
+
+
+def e8m0_to_fp32(u8: jax.Array) -> jax.Array:
+    e8_finfo = jnp.finfo(jnp.float8_e8m0fnu)
+    exponents = u8.astype(jnp.int32) + e8_finfo.minexp
+    ones = jnp.ones_like(u8, dtype=jnp.float32)
+    return jnp.ldexp(ones, exponents)
+
+
+def dequantize_block_weight(weight: jax.Array,
+                            scale: jax.Array,
+                            block_size: int,
+                            out_dtype: jnp.dtype = jnp.bfloat16) -> jax.Array:
+    orig_shape = weight.shape
+    weight_block = weight.reshape(orig_shape[:-1] + (-1, block_size))
+    weight_dequantized = weight_block.astype(jnp.float32) * jnp.expand_dims(
+        scale, -1)
+    return weight_dequantized.reshape(orig_shape).astype(out_dtype)
+
+
+@register_quantization_config(get_tpu_quant_method(MXFP4))
+class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
+
+    @classmethod
+    def get_name(cls):
+        return MXFP4
+
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["QuantizeMethodBase"]:
+        from vllm.attention.layer import Attention  # Avoid circular import
+
+        if isinstance(layer, LinearBase):
+            linear_config = self.get_linear_config(layer)
+            if self.ignored_layers and is_layer_skipped(
+                    prefix=prefix,
+                    ignored_layers=self.ignored_layers,
+                    fused_mapping=self.packed_modules_mapping,
+            ):
+                return VllmUnquantizedLinearMethod(linear_config)
+            # TODO: Add support for an MXFP4 LinearMethod.
+            # An MXFP4 LinearMethod is available in AMD-Quark; refer to that
+            # implementation if you are interested in enabling MXFP4 here.
+            logger.warning_once(
+                "MXFP4 linear layer is not implemented - falling back to "
+                "UnquantizedLinearMethod.")
+            return VllmUnquantizedLinearMethod(linear_config)
+        elif isinstance(layer, FusedMoE):
+            return VllmMxfp4MoEMethod(layer.moe_config, self.mesh)
+        elif isinstance(layer, Attention):
+            # TODO: Add support for MXFP4 Attention.
+            logger.warning_once("MXFP4 attention layer is not implemented. "
+                                "Skipping quantization for this layer.")
+        return None
+
+
+class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
+
+    def __init__(self, moe: FusedMoEConfig, mesh: Mesh):
+        FusedMoEMethodBase.__init__(self, moe)
+
+        # We piggyback on the Triton implementation as it applies minimal
+        # hardware-specific post-processing to the weights.
+        self.mxfp4_backend = Mxfp4Backend.TRITON
+        self.mesh = mesh
+
+    def get_fused_moe_quant_config(
+            self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None:
+        # Because the weights have been dequantized, only the biased moe
+        # config is needed.
+        # TODO(kyuyeunk): Add native support for MXFP4.
+        return biased_moe_quant_config(
+            layer.w13_bias,
+            layer.w2_bias,
+        )
+
+    def process_weights_after_loading(self, layer: torch.nn.Module):
+        assert isinstance(layer, FusedMoE)
+
+        w13_weight = u8_unpack_e2m1(t2j(layer.w13_weight, use_dlpack=False))
+        w13_weight_scale = e8m0_to_fp32(
+            t2j(layer.w13_weight_scale, use_dlpack=False))
+        w13_bias = t2j(layer.w13_bias, use_dlpack=False)
+
+        w2_weight = u8_unpack_e2m1(t2j(layer.w2_weight, use_dlpack=False))
+        w2_weight_scale = e8m0_to_fp32(
+            t2j(layer.w2_weight_scale, use_dlpack=False))
+        w2_bias = t2j(layer.w2_bias, use_dlpack=False)
+
+        # We dequantize fp4 weights into bf16.
+        # TODO(kyuyeunk): Add native support for MXFP4.
+        w13_weight = dequantize_block_weight(w13_weight, w13_weight_scale,
+                                             MXFP4_BLOCK_SIZE, jnp.bfloat16)
+        w2_weight = dequantize_block_weight(w2_weight, w2_weight_scale,
+                                            MXFP4_BLOCK_SIZE, jnp.bfloat16)
+
+        # Because the weights have been dequantized, the scales are no longer
+        # used.
+        delattr(layer, "w13_weight_scale")
+        delattr(layer, "w2_weight_scale")
+
+        if layer.activation == "swigluoai":
+            # When using swigluoai, vLLM splits the gmm output in an
+            # interleaved way. However, an interleaved split is not performant
+            # on TPU, so we preprocess the weight so that splitting the gmm
+            # output down the middle still gives the same result.
+            w1_weight = w13_weight[:, ::2, :]
+            w3_weight = w13_weight[:, 1::2, :]
+            w13_weight = jnp.concat([w1_weight, w3_weight], axis=1)
+
+            w1_bias = w13_bias[:, ::2]
+            w3_bias = w13_bias[:, 1::2]
+            w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
+
+        # TODO(kyuyeunk): Add weight processing logic for the new kernel.
+        if layer.use_ep:
+            w13_weight = jax.device_put(
+                w13_weight,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P("model", None, None))))
+            w2_weight = jax.device_put(
+                w2_weight,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P("model", None, None))))
+
+            w13_bias = jax.device_put(
+                w13_bias,
+                Format(Layout((0, 1)),
+                       NamedSharding(self.mesh, P("model", None))))
+            w2_bias = jax.device_put(
+                w2_bias,
+                Format(Layout((0, 1)),
+                       NamedSharding(self.mesh, P("model", None))))
+
+        else:
+            intermediate_size = w13_weight.shape[1] // 2
+            assert intermediate_size == w2_weight.shape[-1]
+            output_sizes = [intermediate_size, intermediate_size]
+            n_shards = self.mesh.shape["model"]
+            assert intermediate_size % n_shards == 0
+            w13_weight = reorder_concatenated_tensor_for_sharding(w13_weight,
+                                                                  output_sizes,
+                                                                  n_shards,
+                                                                  dim=1)
+            w13_weight = jax.device_put(
+                w13_weight,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P(None, "model", None))))
+            w2_weight = jax.device_put(
+                w2_weight,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P(None, None, "model"))))
+
+            w13_bias = reorder_concatenated_tensor_for_sharding(w13_bias,
+                                                                output_sizes,
+                                                                n_shards,
+                                                                dim=1)
+            w13_bias = jax.device_put(
+                w13_bias,
+                Format(Layout((0, 1)),
+                       NamedSharding(self.mesh, P(None, "model"))))
+            w2_bias = jax.device_put(
+                w2_bias,
+                Format(Layout((0, 1)),
+                       NamedSharding(self.mesh, P(None, None))))
+
+        layer.w13_weight = Parameter(torch_view(w13_weight),
+                                     requires_grad=False)
+        layer.w13_bias = Parameter(torch_view(w13_bias), requires_grad=False)
+
+        layer.w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
+        layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        assert isinstance(layer, FusedMoE)
+        if scoring_func != "softmax":
+            raise NotImplementedError(
+                "Only softmax is supported for scoring_func")
+
+        # Use the original implementation
+        output = fused_moe_func_padded(
+            jax_view(x),
+            jax_view(layer.w13_weight),
+            jax_view(layer.w2_weight),
+            jax_view(layer.w13_bias) if self.moe.has_bias else None,
+            jax_view(layer.w2_bias) if self.moe.has_bias else None,
+            jax_view(router_logits),
+            topk=top_k,
+            global_num_experts=global_num_experts,
+            renormalize=renormalize,
+            reduce_results=layer.reduce_results,
+            mesh=self.mesh,
+            use_ep=layer.use_ep,
+            activation=activation,
+        )
+
+        return torch_view(output)
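
For context, the MXFP4 checkpoints handled above pack two e2m1 values into each uint8 byte and attach one e8m0 (power-of-two) scale per 32-value block; process_weights_after_loading simply unpacks and dequantizes everything to bf16 before sharding, rather than running a native MXFP4 kernel. A small usage sketch of the module-level helpers defined in this file, with toy shapes rather than real checkpoint data; it assumes a JAX build with float4_e2m1fn support:

import jax.numpy as jnp

from tpu_inference.layers.vllm.quantization.mxfp4 import (
    MXFP4_BLOCK_SIZE, dequantize_block_weight, e8m0_to_fp32, u8_unpack_e2m1)

experts, out_dim, in_dim = 2, 4, 2 * MXFP4_BLOCK_SIZE  # toy MoE weight shape

# Packed layout: two e2m1 nibbles per byte, one e8m0 scale byte per block.
packed = jnp.zeros((experts, out_dim, in_dim // 2), dtype=jnp.uint8)
scales_u8 = jnp.full((experts, out_dim, in_dim // MXFP4_BLOCK_SIZE), 127,
                     dtype=jnp.uint8)  # 127 encodes a scale of 1.0 in e8m0

weight_fp4 = u8_unpack_e2m1(packed)  # [experts, out, in] e2m1 values
scales = e8m0_to_fp32(scales_u8)     # power-of-two scales as fp32
weight_bf16 = dequantize_block_weight(weight_fp4, scales, MXFP4_BLOCK_SIZE)

print(weight_bf16.shape, weight_bf16.dtype)  # (2, 4, 64) bfloat16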