tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
@@ -0,0 +1,150 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional
+
+ import jax
+ import jax.numpy as jnp
+ import torch
+ from compressed_tensors.quantization import QuantizationStrategy
+ from jax.sharding import NamedSharding, PartitionSpec
+ from torchax.interop import jax_view, torch_view
+ from vllm.logger import init_logger
+ from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
+     CompressedTensorsW8A8Int8
+ from vllm.model_executor.layers.quantization.utils.w8a8_utils import \
+     convert_to_channelwise
+
+ from tpu_inference.layers.vllm.linear_common import (
+     sharded_quantized_matmul, slice_sharded_tensor_for_concatenation,
+     torch_to_jax_param)
+ from tpu_inference.layers.vllm.quantization.common import JaxCommonLinearConfig
+
+ P = PartitionSpec
+ logger = init_logger(__name__)
+
+
+ class VllmCompressedTensorsW8A8Int8(CompressedTensorsW8A8Int8):
+
+     def __init__(self, strategy: str, is_static_input_scheme: bool,
+                  input_symmetric: bool, jax_config: JaxCommonLinearConfig):
+         super().__init__(strategy, is_static_input_scheme, input_symmetric)
+
+         self.jax_config = jax_config
+         self.is_channelwise = (self.strategy == QuantizationStrategy.CHANNEL)
+
+     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+         weight = torch_to_jax_param(
+             layer.weight,
+             NamedSharding(self.jax_config.mesh,
+                           self.jax_config.weight_sharding),
+             self.jax_config.output_sizes,
+             self.jax_config.n_shards,
+             self.jax_config.fuse_matmuls,
+         )
+         delattr(layer, "weight")
+         layer.weight = weight
+
+         weight_scale = layer.weight_scale
+         is_fused_module = len(layer.logical_widths) > 1
+         if is_fused_module and not self.is_channelwise:
+             weight_scale = convert_to_channelwise(weight_scale,
+                                                   layer.logical_widths)
+         weight_scale = weight_scale.squeeze(-1)
+
+         weight_scale = torch_to_jax_param(
+             weight_scale,
+             NamedSharding(self.jax_config.mesh, self.jax_config.bias_sharding),
+             self.jax_config.output_sizes,
+             self.jax_config.n_shards,
+             self.jax_config.fuse_matmuls,
+         )
+         delattr(layer, "weight_scale")
+         layer.weight_scale = weight_scale
+
+         if layer.bias is not None and not layer.skip_bias_add:
+             if layer.return_bias:
+                 logger.warning_once("Bias might return incorrect value.")
+
+             bias = torch_to_jax_param(
+                 layer.bias,
+                 NamedSharding(self.jax_config.mesh,
+                               self.jax_config.bias_sharding),
+                 self.jax_config.output_sizes,
+                 self.jax_config.n_shards,
+                 self.jax_config.fuse_matmuls,
+             )
+             delattr(layer, "bias")
+             layer.bias = bias
+
+         # TODO(kyuyeunk): Support static range input quantization.
+         assert getattr(layer, "input_scale", None) is None
+         assert getattr(layer, "input_zero_point", None) is None
+         assert getattr(layer, "azp_adj", None) is None
+
+     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+                       bias: Optional[torch.Tensor]) -> torch.Tensor:
+         with jax.named_scope(layer._get_name()):
+             if self.jax_config.fuse_matmuls:
+                 out = self._apply_fused(layer, x, bias)
+             else:
+                 out = self._apply_split(layer, x, bias)
+
+         return out
+
+     def _apply_fused(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]) -> torch.Tensor:
+         x_jax = jax_view(x)
+         weight_jax = jax_view(layer.weight)
+         weight_scale_jax = jax_view(layer.weight_scale)
+
+         outs = sharded_quantized_matmul(
+             x_jax,
+             weight_jax,
+             weight_scale_jax,
+             self.jax_config.mesh,
+             self.jax_config.weight_sharding,
+         )
+         if bias is not None and not layer.skip_bias_add:
+             outs += jax_view(bias)
+
+         outs = slice_sharded_tensor_for_concatenation(
+             outs, self.jax_config.output_sizes, self.jax_config.n_shards)
+         out = jnp.concatenate(outs, axis=-1)
+         return torch_view(out)
+
+     def _apply_split(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]) -> torch.Tensor:
+         assert isinstance(layer.weight, torch.nn.ParameterList)
+
+         x_jax = jax_view(x)
+         outs = []
+         for i, (weight, weight_scale) in enumerate(
+                 zip(layer.weight, layer.weight_scale)):
+             weight_jax = jax_view(weight)
+             weight_scale_jax = jax_view(weight_scale)
+
+             out = sharded_quantized_matmul(
+                 x_jax,
+                 weight_jax,
+                 weight_scale_jax,
+                 self.jax_config.mesh,
+                 self.jax_config.weight_sharding,
+             )
+             if bias is not None and not layer.skip_bias_add:
+                 out += jax_view(bias[i])
+
+             outs.append(out)
+         out = jnp.concatenate(outs, axis=-1)
+         return torch_view(out)
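The W8A8-int8 scheme above hands the actual math to `sharded_quantized_matmul` and `torch_to_jax_param` from `tpu_inference.layers.vllm.linear_common`. For orientation only, here is a self-contained jax.numpy sketch of the per-channel int8 weight quantization and scale-dequantized matmul that such a scheme performs; the helper names and shapes are illustrative assumptions, not part of this package.

    # Illustrative sketch only: per-channel int8 weight quantization and the
    # matching dequantized matmul. Helper names are hypothetical, not from
    # tpu-inference.
    import jax.numpy as jnp

    def quantize_per_channel(w):
        # w: (out_features, in_features); one scale per output channel.
        scale = jnp.max(jnp.abs(w), axis=1, keepdims=True) / 127.0
        w_q = jnp.clip(jnp.round(w / scale), -128, 127).astype(jnp.int8)
        return w_q, scale.squeeze(-1)

    def int8_matmul(x, w_q, scale, bias=None):
        # Accumulate against the int8 weight, then fold the scale back in.
        out = jnp.einsum("bi,oi->bo", x, w_q.astype(x.dtype)) * scale
        return out if bias is None else out + bias

    w_q, scale = quantize_per_channel(jnp.ones((256, 128)) * 0.02)
    y = int8_matmul(jnp.ones((4, 128)), w_q, scale)   # y: (4, 256)

The per-output-channel scale in the sketch mirrors the `convert_to_channelwise(...)` plus `squeeze(-1)` handling in `process_weights_after_loading` above.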
tpu_inference/layers/vllm/quantization/fp8.py
@@ -0,0 +1,118 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional, Union
+
+ import jax
+ import torch
+ from jax.sharding import PartitionSpec
+ from vllm.logger import init_logger
+ from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+ from vllm.model_executor.layers.quantization import \
+     register_quantization_config
+ from vllm.model_executor.layers.quantization.base_config import \
+     QuantizeMethodBase
+ from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
+                                                          Fp8LinearMethod)
+ from vllm.model_executor.layers.quantization.utils.quant_utils import \
+     is_layer_skipped
+
+ from tpu_inference.layers.common.quant_methods import FP8, get_tpu_quant_method
+ from tpu_inference.layers.vllm.quantization.common import (
+     JaxCommonConfig, JaxCommonLinearConfig)
+ from tpu_inference.layers.vllm.quantization.unquantized import \
+     VllmUnquantizedLinearMethod
+
+ P = PartitionSpec
+ logger = init_logger(__name__)
+
+
+ @register_quantization_config(get_tpu_quant_method(FP8))
+ class VllmFp8Config(Fp8Config, JaxCommonConfig):
+
+     @classmethod
+     def get_name(cls):
+         return FP8
+
+     def get_supported_act_dtypes(self) -> list[torch.dtype]:
+         return [torch.bfloat16]
+
+     def get_quant_method(
+         self, layer: torch.nn.Module, prefix: str
+     ) -> Optional[Union["LinearMethodBase", "QuantizeMethodBase"]]:
+         if isinstance(layer, LinearBase):
+             linear_config = self.get_linear_config(layer)
+             if is_layer_skipped(prefix, self.ignored_layers):
+                 return VllmUnquantizedLinearMethod(linear_config)
+             return VllmFp8LinearMethod(self, linear_config)
+         elif isinstance(layer, FusedMoE):
+             raise NotImplementedError(
+                 "FP8 FusedMoE is currently not supported in torchax-jax")
+         return None
+
+
+ class VllmFp8LinearMethod(Fp8LinearMethod):
+
+     def __init__(self, quant_config: VllmFp8Config,
+                  jax_config: JaxCommonLinearConfig):
+         super().__init__(quant_config)
+         self.jax_config = jax_config
+         self._configure_sharding()
+
+     def _configure_sharding(self) -> None:
+
+         raise NotImplementedError(
+             "Configure PartitionSpec for weight_sharding and scale_sharding "
+             "based on layer type (RowParallel/ColumnParallel)")
+
+     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+
+         raise NotImplementedError(
+             "Convert layer.weight, layer.weight_scale, and optionally "
+             "layer.input_scale and layer.bias from torch tensors to JAX arrays "
+             "using torch_to_jax_param() with appropriate sharding")
+
+     def apply(self,
+               layer: torch.nn.Module,
+               x: torch.Tensor,
+               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+         with jax.named_scope(layer._get_name()):
+             if self.jax_config.fuse_matmuls:
+                 out = self._apply_fused(layer, x, bias)
+             else:
+                 out = self._apply_split(layer, x, bias)
+
+         return out
+
+     def _apply_fused(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+         raise NotImplementedError(
+             "Implement single matmul for fused outputs: "
+             "quantize input to fp8, perform fp8 matmul with weight and scales, "
+             "dequantize output, and add bias if present")
+
+     def _apply_split(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+         raise NotImplementedError(
+             "Implement separate matmuls per output partition: "
+             "split weight/scale by output_sizes, perform fp8 matmul for each, "
+             "concatenate results, and add bias if present")
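`VllmFp8LinearMethod` above deliberately leaves `_configure_sharding`, `process_weights_after_loading`, `_apply_fused`, and `_apply_split` as `NotImplementedError` stubs whose messages describe the intended flow: quantize the activation to fp8, matmul against the fp8 weight, rescale with both scales, and add the bias. The sketch below illustrates that flow in plain jax.numpy under the assumption of a per-tensor activation scale and a per-output-channel weight scale; it is not the package's implementation, and the names are hypothetical.

    # Hypothetical sketch of the fp8 linear flow described by the stub
    # messages above; not the tpu-inference implementation.
    import jax.numpy as jnp

    def fp8_linear(x, w_fp8, w_scale, bias=None):
        # Per-tensor activation scale; 448 is the largest finite e4m3 value.
        x_scale = jnp.max(jnp.abs(x)).astype(jnp.float32) / 448.0
        x_fp8 = (x.astype(jnp.float32) / x_scale).astype(jnp.float8_e4m3fn)

        # Accumulate in f32, then dequantize with both scales.
        acc = jnp.einsum("bi,oi->bo", x_fp8.astype(jnp.float32),
                         w_fp8.astype(jnp.float32))
        out = acc * x_scale * w_scale            # w_scale: (out_features,)
        return (out if bias is None else out + bias).astype(x.dtype)

    x = jnp.ones((4, 128), dtype=jnp.bfloat16)
    w_fp8 = jnp.full((256, 128), 0.5, dtype=jnp.float8_e4m3fn)
    y = fp8_linear(x, w_fp8, jnp.full((256,), 0.01))   # y: (4, 256)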
tpu_inference/layers/vllm/quantization/mxfp4.py
@@ -0,0 +1,396 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional, Union
+
+ import jax
+ import jax.numpy as jnp
+ import torch
+ from jax.experimental.layout import Format, Layout
+ from jax.sharding import Mesh, NamedSharding, PartitionSpec
+ from torch.nn.parameter import Parameter
+ from torchax.interop import jax_view, torch_view
+ from torchax.ops.mappings import t2j
+ from vllm.logger import init_logger
+ from vllm.model_executor.layers.fused_moe.config import (
+     FusedMoEConfig, FusedMoEQuantConfig, mxfp4_w4a16_moe_quant_config)
+ from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
+                                                          FusedMoEMethodBase)
+ from vllm.model_executor.layers.linear import LinearBase
+ from vllm.model_executor.layers.quantization import \
+     register_quantization_config
+ from vllm.model_executor.layers.quantization.base_config import \
+     QuantizeMethodBase
+ from vllm.model_executor.layers.quantization.mxfp4 import (Mxfp4Backend,
+                                                             Mxfp4Config,
+                                                             Mxfp4MoEMethod)
+ from vllm.model_executor.layers.quantization.utils.quant_utils import \
+     is_layer_skipped
+
+ from tpu_inference import envs
+ from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
+ from tpu_inference.layers.common.quant_methods import (MXFP4,
+                                                         get_tpu_quant_method)
+ from tpu_inference.layers.common.quantization import (
+     dequantize_tensor_from_mxfp4_packed, quantize_tensor)
+ from tpu_inference.layers.vllm.fused_moe import fused_moe_func
+ from tpu_inference.layers.vllm.linear_common import \
+     reorder_concatenated_tensor_for_sharding
+ from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
+ from tpu_inference.layers.vllm.quantization.unquantized import \
+     VllmUnquantizedLinearMethod
+
+ REQUANTIZED_BLOCK_SIZE = 512
+
+ P = PartitionSpec
+
+ logger = init_logger(__name__)
+
+
+ @register_quantization_config(get_tpu_quant_method(MXFP4))
+ class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
+
+     @classmethod
+     def get_name(cls):
+         return MXFP4
+
+     def get_quant_method(self, layer: torch.nn.Module,
+                          prefix: str) -> Optional["QuantizeMethodBase"]:
+         from vllm.attention.layer import Attention  # Avoid circular import
+
+         if isinstance(layer, LinearBase):
+             linear_config = self.get_linear_config(layer)
+             if self.ignored_layers and is_layer_skipped(
+                     prefix=prefix,
+                     ignored_layers=self.ignored_layers,
+                     fused_mapping=self.packed_modules_mapping,
+             ):
+                 return VllmUnquantizedLinearMethod(linear_config)
+             logger.warning_once(
+                 "MXFP4 linear layer is not implemented - falling back to "
+                 "UnquantizedLinearMethod.")
+             return VllmUnquantizedLinearMethod(linear_config)
+         elif isinstance(layer, FusedMoE):
+             moe_config = self.get_moe_config(layer)
+             return VllmMxfp4MoEMethod(moe_config, self.mesh)
+         elif isinstance(layer, Attention):
+             logger.warning_once("MXFP4 attention layer is not implemented. "
+                                 "Skipping quantization for this layer.")
+         return None
+
+
+ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
+
+     def __init__(self,
+                  moe: FusedMoEConfig,
+                  mesh: Mesh,
+                  ep_axis_name: str = 'model'):
+         FusedMoEMethodBase.__init__(self, moe)
+
+         # We piggyback on the Triton implementation as it applies minimal
+         # hardware-specific post-processing to the weights.
+         self.mxfp4_backend = Mxfp4Backend.TRITON
+
+         self.mesh = mesh
+         self.use_kernel = envs.USE_MOE_EP_KERNEL and moe.use_ep
+         self.ep_axis_name = ep_axis_name
+         # TODO: Use autotune table once we have it.
+         self.block_size = {
+             "bt": 256,
+             "bf": 1024,
+             "bd1": 1024,
+             "bd2": 1024,
+             "btc": 256,
+             "bfc": 1024,
+             "bd1c": 1024,
+             "bd2c": 1024,
+         }
+
+     def get_fused_moe_quant_config(
+             self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None:
+         return mxfp4_w4a16_moe_quant_config(
+             w1_scale=layer.w13_weight_scale,
+             w2_scale=layer.w2_weight_scale,
+             w1_bias=layer.w13_bias,
+             w2_bias=layer.w2_bias,
+         )
+
+     def process_weights_after_loading(self, layer: torch.nn.Module):
+         assert isinstance(layer, FusedMoE)
+         assert layer.moe_config.has_bias, "mxfp4 quantization always uses bias."
+
+         w13_weight = t2j(layer.w13_weight, use_dlpack=False)
+         w13_weight_scale = t2j(layer.w13_weight_scale, use_dlpack=False)
+         w13_bias = t2j(layer.w13_bias, use_dlpack=False)
+
+         w2_weight = t2j(layer.w2_weight, use_dlpack=False)
+         w2_weight_scale = t2j(layer.w2_weight_scale, use_dlpack=False)
+         w2_bias = t2j(layer.w2_bias, use_dlpack=False)
+
+         # Wrap the function in jit to speed up requantization.
+         @jax.jit
+         def wrapper(w13_weight, w13_weight_scale, w13_bias, w2_weight,
+                     w2_weight_scale, w2_bias):
+             # Dequantize fp4 weights into fp32.
+             w13_weight = dequantize_tensor_from_mxfp4_packed(
+                 w13_weight, w13_weight_scale, 2)
+             w2_weight = dequantize_tensor_from_mxfp4_packed(
+                 w2_weight, w2_weight_scale, 2)
+
+             num_experts, orig_hidden_size, orig_intermediate_size = w2_weight.shape
+
+             # Requantize the weights into a TPU-friendly block size.
+             w13_weight, w13_weight_scale = quantize_tensor(
+                 jnp.float4_e2m1fn, w13_weight, 2, REQUANTIZED_BLOCK_SIZE, True)
+             w2_weight, w2_weight_scale = quantize_tensor(
+                 jnp.float4_e2m1fn, w2_weight, 2, REQUANTIZED_BLOCK_SIZE, True)
+
+             intermediate_size = w2_weight.shape[-1]
+             hidden_size = w13_weight.shape[-1]
+
+             # Dims may have been padded to align with the subchannel size during
+             # quantization. We pad the corresponding dim on the other weight.
+             # NOTE: We perform padding after quantization because the padding
+             # value can affect quantization numerics.
+             intermediate_padding_size = 2 * (intermediate_size -
+                                              orig_intermediate_size)
+             w13_weight = jnp.pad(w13_weight,
+                                  ((0, 0), (0, intermediate_padding_size),
+                                   (0, 0)))
+             w13_weight_scale = jnp.pad(w13_weight_scale,
+                                        ((0, 0), (0, intermediate_padding_size),
+                                         (0, 0)))
+             w13_bias = jnp.pad(w13_bias,
+                                ((0, 0), (0, intermediate_padding_size)))
+
+             hidden_padding_size = hidden_size - orig_hidden_size
+             w2_weight = jnp.pad(w2_weight,
+                                 ((0, 0), (0, hidden_padding_size), (0, 0)))
+             w2_weight_scale = jnp.pad(w2_weight_scale,
+                                       ((0, 0), (0, hidden_padding_size),
+                                        (0, 0)))
+             w2_bias = jnp.pad(w2_bias, ((0, 0), (0, hidden_padding_size)))
+
+             if layer.activation == "swigluoai":
+                 # When using swigluoai, vLLM splits the gmm output in an
+                 # interleaved way. However, an interleaved split is not
+                 # performant on TPU, so we preprocess the weight so that
+                 # splitting the gmm output down the middle gives the same result.
+                 w1_weight = w13_weight[:, ::2, :]
+                 w3_weight = w13_weight[:, 1::2, :]
+                 w13_weight = jnp.concat([w1_weight, w3_weight], axis=1)
+
+                 w1_weight_scale = w13_weight_scale[:, ::2, :]
+                 w3_weight_scale = w13_weight_scale[:, 1::2, :]
+                 w13_weight_scale = jnp.concat(
+                     [w1_weight_scale, w3_weight_scale], axis=1)
+
+                 w1_bias = w13_bias[:, ::2]
+                 w3_bias = w13_bias[:, 1::2]
+                 w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
+
+             if self.use_kernel:
+                 # Kernel expects:
+                 #   w13: (num_experts, 2, hidden_size, intermediate_size)
+                 #   w2: (num_experts, intermediate_size, hidden_size)
+                 # Current format:
+                 #   w13_weight: (num_experts, 2*intermediate_size, hidden_size)
+                 #   w2_weight: (num_experts, hidden_size, intermediate_size)
+
+                 w13_weight = w13_weight.reshape(num_experts, 2,
+                                                 intermediate_size, hidden_size)
+
+                 w13_weight_scale = w13_weight_scale.reshape(
+                     num_experts, 2, intermediate_size, 1, -1)
+                 w2_weight_scale = w2_weight_scale.reshape(
+                     num_experts, hidden_size, 1, -1)
+
+                 w13_bias = w13_bias.astype(jnp.float32).reshape(
+                     num_experts, 2, 1, intermediate_size)
+                 w2_bias = w2_bias.astype(jnp.float32).reshape(
+                     num_experts, 1, hidden_size)
+
+                 # Transpose the non-contracting dim to the rightmost dim.
+                 w13_weight = jnp.swapaxes(w13_weight, 2, 3)
+                 w2_weight = jnp.swapaxes(w2_weight, 1, 2)
+
+                 w13_weight_scale = jnp.swapaxes(w13_weight_scale, 2, 4)
+                 w2_weight_scale = jnp.swapaxes(w2_weight_scale, 1, 3)
+
+                 # Apply EP sharding.
+                 ep_sharding = NamedSharding(self.mesh, P("model"))
+
+                 w13_weight = jax.lax.with_sharding_constraint(
+                     w13_weight, Format(Layout((0, 1, 2, 3)), ep_sharding))
+                 w2_weight = jax.lax.with_sharding_constraint(
+                     w2_weight, Format(Layout((0, 1, 2)), ep_sharding))
+
+                 w13_weight_scale = jax.lax.with_sharding_constraint(
+                     w13_weight_scale,
+                     Format(Layout((0, 1, 2, 3, 4)), ep_sharding))
+                 w2_weight_scale = jax.lax.with_sharding_constraint(
+                     w2_weight_scale, Format(Layout((0, 1, 2, 3)), ep_sharding))
+
+                 w13_bias = jax.lax.with_sharding_constraint(
+                     w13_bias, Format(Layout((0, 1, 2, 3)), ep_sharding))
+                 w2_bias = jax.lax.with_sharding_constraint(
+                     w2_bias, Format(Layout((0, 1, 2)), ep_sharding))
+             else:
+                 w13_weight_scale = jnp.swapaxes(w13_weight_scale, 1, 2)
+                 w13_weight_scale = jnp.expand_dims(w13_weight_scale, 2)
+                 w2_weight_scale = jnp.swapaxes(w2_weight_scale, 1, 2)
+                 w2_weight_scale = jnp.expand_dims(w2_weight_scale, 2)
+
+                 w13_bias = jnp.expand_dims(w13_bias, 1)
+                 w2_bias = jnp.expand_dims(w2_bias, 1)
+
+                 if layer.use_ep:
+                     ep_sharding = NamedSharding(self.mesh, P("model"))
+
+                     w13_weight = jax.lax.with_sharding_constraint(
+                         w13_weight, ep_sharding)
+                     w2_weight = jax.lax.with_sharding_constraint(
+                         w2_weight, ep_sharding)
+
+                     w13_weight_scale = jax.lax.with_sharding_constraint(
+                         w13_weight_scale, ep_sharding)
+                     w2_weight_scale = jax.lax.with_sharding_constraint(
+                         w2_weight_scale, ep_sharding)
+
+                     w13_bias = jax.lax.with_sharding_constraint(
+                         w13_bias, ep_sharding)
+                     w2_bias = jax.lax.with_sharding_constraint(
+                         w2_bias, ep_sharding)
+
+                 else:
+                     output_sizes = [intermediate_size, intermediate_size]
+                     n_shards = self.mesh.shape["model"]
+                     assert intermediate_size % n_shards == 0
+
+                     # Reorder w13 weights so that splitting the w1 and w3 output
+                     # can happen locally without any collective operations.
+                     w13_weight = reorder_concatenated_tensor_for_sharding(
+                         w13_weight,
+                         output_sizes,
+                         n_shards,
+                         dim=1,
+                     )
+                     w13_weight_scale = reorder_concatenated_tensor_for_sharding(
+                         w13_weight_scale,
+                         output_sizes,
+                         n_shards,
+                         dim=3,
+                     )
+                     w13_bias = reorder_concatenated_tensor_for_sharding(
+                         w13_bias,
+                         output_sizes,
+                         n_shards,
+                         dim=2,
+                     )
+
+                     w13_weight = jax.lax.with_sharding_constraint(
+                         w13_weight,
+                         NamedSharding(self.mesh, P(None, "model", None)))
+                     w2_weight = jax.lax.with_sharding_constraint(
+                         w2_weight,
+                         NamedSharding(self.mesh, P(None, None, "model")))
+                     w13_weight_scale = jax.lax.with_sharding_constraint(
+                         w13_weight_scale,
+                         NamedSharding(self.mesh, P(None, None, None, "model")))
+                     w2_weight_scale = jax.lax.with_sharding_constraint(
+                         w2_weight_scale,
+                         NamedSharding(self.mesh, P(None, "model", None, None)))
+                     w13_bias = jax.lax.with_sharding_constraint(
+                         w13_bias,
+                         NamedSharding(self.mesh, P(None, None, "model")))
+                     w2_bias = jax.lax.with_sharding_constraint(
+                         w2_bias, NamedSharding(self.mesh, P(None, None, None)))
+
+             return w13_weight, w13_weight_scale, w13_bias, w2_weight, w2_weight_scale, w2_bias
+
+         w13_weight, w13_weight_scale, w13_bias, w2_weight, w2_weight_scale, w2_bias = wrapper(
+             w13_weight, w13_weight_scale, w13_bias, w2_weight, w2_weight_scale,
+             w2_bias)
+
+         layer.w13_weight = Parameter(torch_view(w13_weight),
+                                      requires_grad=False)
+         layer.w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
+
+         layer.w13_weight_scale = Parameter(torch_view(w13_weight_scale),
+                                            requires_grad=False)
+         layer.w2_weight_scale = Parameter(torch_view(w2_weight_scale),
+                                           requires_grad=False)
+
+         layer.w13_bias = Parameter(torch_view(w13_bias), requires_grad=False)
+         layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False)
+
+     def apply(
+         self,
+         layer: torch.nn.Module,
+         x: torch.Tensor,
+         router_logits: torch.Tensor,
+     ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+         assert isinstance(layer, FusedMoE)
+         if layer.scoring_func != "softmax":
+             raise NotImplementedError(
+                 "Only softmax is supported for scoring_func")
+
+         x = jax_view(x)
+         w13_weight = jax_view(layer.w13_weight)
+         w2_weight = jax_view(layer.w2_weight)
+         w13_weight_scale = jax_view(layer.w13_weight_scale)
+         w2_weight_scale = jax_view(layer.w2_weight_scale)
+         w13_bias = jax_view(layer.w13_bias)
+         w2_bias = jax_view(layer.w2_bias)
+         gating_output = jax_view(router_logits)
+
+         if self.use_kernel:
+             actual_hidden_size = x.shape[-1]
+             padding_size = w13_weight.shape[-2] - actual_hidden_size
+             x = jnp.pad(x, ((0, 0), (0, padding_size)))
+             output = fused_ep_moe(
+                 mesh=self.mesh,
+                 tokens=x,
+                 w1=w13_weight,
+                 w2=w2_weight,
+                 w1_scale=w13_weight_scale,
+                 w2_scale=w2_weight_scale,
+                 b1=w13_bias,
+                 b2=w2_bias,
+                 gating_output=gating_output,
+                 subc_quant_wsz=REQUANTIZED_BLOCK_SIZE,
+                 top_k=layer.top_k,
+                 ep_axis_name=self.ep_axis_name,
+                 renormalize_topk_logits=layer.renormalize,
+                 act_fn=layer.activation,
+                 **self.block_size,
+             )[:, :actual_hidden_size]
+         else:
+             output = fused_moe_func(
+                 hidden_states=x,
+                 w1=w13_weight,
+                 w2=w2_weight,
+                 w1_scale=w13_weight_scale,
+                 w2_scale=w2_weight_scale,
+                 w1_bias=w13_bias,
+                 w2_bias=w2_bias,
+                 gating_output=gating_output,
+                 topk=layer.top_k,
+                 renormalize=layer.renormalize,
+                 mesh=self.mesh,
+                 use_ep=layer.use_ep,
+                 activation=layer.activation,
+             )
+
+         return torch_view(output)
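The swigluoai branch of `process_weights_after_loading` above de-interleaves the rows of `w13` so that the gmm output can later be split down the middle instead of with strided `[::2]` / `[1::2]` slices. A minimal sketch of why that reordering is equivalent, using toy shapes and a plain matmul in place of the package's grouped matmul:

    # Toy demonstration of the swigluoai weight reordering; the shapes and the
    # plain matmul are stand-ins, not the package's gmm path.
    import jax.numpy as jnp
    from jax import random

    num_experts, two_i, hidden = 2, 8, 16      # two_i = 2 * intermediate_size
    w13 = random.normal(random.PRNGKey(0), (num_experts, two_i, hidden))
    x = random.normal(random.PRNGKey(1), (4, hidden))

    # Reference: interleaved split of the per-expert output (vLLM's layout).
    out_ref = x @ w13[0].T
    w1_ref, w3_ref = out_ref[:, ::2], out_ref[:, 1::2]

    # Reorder the rows once, as the swigluoai branch does; a contiguous middle
    # split of the output then yields the same two halves.
    w13_reordered = jnp.concatenate([w13[:, ::2, :], w13[:, 1::2, :]], axis=1)
    out = x @ w13_reordered[0].T
    w1, w3 = jnp.split(out, 2, axis=-1)

    assert jnp.allclose(w1, w1_ref) and jnp.allclose(w3, w3_ref)

The same slicing is applied to the weight scales and biases above, so all three stay consistent with the reordered weight.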