tpu_inference-0.12.0.dev20251222-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py
@@ -0,0 +1,135 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional
+
+ import torch
+ from jax.sharding import PartitionSpec
+ from vllm.attention.layer import Attention
+ from vllm.logger import init_logger
+ from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+ from vllm.model_executor.layers.linear import LinearBase
+ from vllm.model_executor.layers.quantization import \
+     register_quantization_config
+ from vllm.model_executor.layers.quantization.base_config import \
+     QuantizeMethodBase  # noqa: E501
+ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
+     CompressedTensorsConfig, CompressedTensorsKVCacheMethod,
+     CompressedTensorsLinearMethod, CompressedTensorsScheme)
+ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+     find_matched_target, should_ignore_layer)
+
+ from tpu_inference.layers.common.quant_methods import (COMPRESSED_TENSORS,
+                                                        get_tpu_quant_method)
+ from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
+ from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
+     VllmCompressedTensorsMoEMethod
+ from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
+     VllmCompressedTensorsW8A8Fp8
+ from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
+     VllmCompressedTensorsW8A8Int8
+ from tpu_inference.layers.vllm.quantization.unquantized import \
+     VllmUnquantizedConfig
+
+ P = PartitionSpec
+ logger = init_logger(__name__)
+
+
+ @register_quantization_config(get_tpu_quant_method(COMPRESSED_TENSORS))
+ class VllmCompressedTensorsConfig(CompressedTensorsConfig, JaxCommonConfig):
+
+     @classmethod
+     def get_name(cls) -> str:
+         return COMPRESSED_TENSORS
+
+     def get_scheme(self,
+                    layer: torch.nn.Module,
+                    layer_name: Optional[str] = None
+                    ) -> Optional["CompressedTensorsScheme"]:
+         """
+         compressed-tensors supports non-uniform quantization in the following way:
+
+         targets of config_groups: There can be N config_groups which each
+         have a quantization scheme. Each config_group has a list of targets
+         which can be a full layer_name, a regex for a layer_name, or
+         an nn.Module name.
+
+         Detect whether a layer_name is found in any target and
+         use the quantization scheme corresponding to the matched target
+         to select the CompressedTensorsScheme used for inference.
+         """
+
+         # Will be empty for models with only sparsity
+         weight_quant = input_quant = None
+         if self.target_scheme_map:
+             matched_target = find_matched_target(
+                 layer_name=layer_name,
+                 module=layer,
+                 targets=self.target_scheme_map.keys(),
+                 fused_mapping=self.packed_modules_mapping,
+             )
+
+             scheme_dict = self.target_scheme_map[matched_target]
+             weight_quant = scheme_dict.get("weights")
+             input_quant = scheme_dict.get("input_activations")
+
+         if weight_quant is None:
+             logger.warning_once("Acceleration for non-quantized schemes is "
+                                 "not supported by Compressed Tensors. "
+                                 "Falling back to UnquantizedLinearMethod")
+             return None
+
+         # TODO(kyuyeunk): Add support for different act_quant_format
+
+         linear_config = self.get_linear_config(layer)
+         if self._is_fp8_w8a8(weight_quant, input_quant):
+             is_static_input_scheme = input_quant and not input_quant.dynamic
+             return VllmCompressedTensorsW8A8Fp8(
+                 weight_quant=weight_quant,
+                 is_static_input_scheme=is_static_input_scheme,
+                 jax_config=linear_config,
+             )
+         if self._is_dynamic_token_w8a8(weight_quant, input_quant):
+             return VllmCompressedTensorsW8A8Int8(
+                 strategy=weight_quant.strategy,
+                 is_static_input_scheme=False,
+                 input_symmetric=input_quant.symmetric,
+                 jax_config=linear_config,
+             )
+         raise NotImplementedError(
+             "No compressed-tensors compatible scheme was found.")
+
+     def get_quant_method(
+         self,
+         layer: torch.nn.Module,
+         prefix: str,
+     ) -> Optional[QuantizeMethodBase]:
+         if should_ignore_layer(prefix,
+                                ignore=self.ignore,
+                                fused_mapping=self.packed_modules_mapping):
+             return VllmUnquantizedConfig.get_quant_method(self, layer, prefix)
+         if isinstance(layer, LinearBase):
+             scheme = self.get_scheme(layer=layer, layer_name=prefix)
+             if scheme is None:
+                 return VllmUnquantizedConfig.get_quant_method(
+                     self, layer, prefix)
+             layer.scheme = scheme
+             return CompressedTensorsLinearMethod(self)
+         if isinstance(layer, FusedMoE):
+             layer.moe_config = self.get_moe_config(layer)
+             return VllmCompressedTensorsMoEMethod.get_moe_method(
+                 self, layer, layer_name=prefix)
+         if isinstance(layer, Attention):
+             return CompressedTensorsKVCacheMethod(self)
+         return None
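The get_scheme docstring above describes how compressed-tensors config_groups map layer targets to quantization schemes. For reference, a minimal self-contained sketch of that matching idea follows; the dict layout mirrors the usual compressed-tensors checkpoint format, the helper name is hypothetical, and the actual resolution is done by vLLM's find_matched_target (which additionally handles fused/packed module mappings).

import re

hypothetical_quant_config = {
    "config_groups": {
        "group_0": {
            "targets": ["Linear"],  # an nn.Module class name; "re:..." regexes and full layer names also allowed
            "weights": {"num_bits": 8, "type": "float", "strategy": "tensor"},
            "input_activations": {"num_bits": 8, "type": "float", "dynamic": False},
        },
    },
    "ignore": ["lm_head"],
}

def matches_target(layer_name: str, module_cls: str, targets) -> bool:
    # Simplified stand-in for find_matched_target: a target matches by exact
    # layer name, by module class name, or by "re:"-prefixed regex.
    for target in targets:
        if target == layer_name or target == module_cls:
            return True
        if target.startswith("re:") and re.search(target[3:], layer_name):
            return True
    return False

group = hypothetical_quant_config["config_groups"]["group_0"]
print(matches_target("model.layers.0.self_attn.q_proj", "Linear", group["targets"]))  # True
print(matches_target("lm_head", "Linear", hypothetical_quant_config["ignore"]))  # True, so it falls back to unquantized

Layers matched by "ignore" take the unquantized path in get_quant_method above; everything else gets the scheme of its matched config group.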
tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -0,0 +1,266 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Union
+
+ import jax
+ import jax.numpy as jnp
+ import torch
+ from compressed_tensors.quantization import QuantizationArgs
+ from jax.experimental.layout import Format, Layout
+ from jax.sharding import Mesh, NamedSharding
+ from jax.sharding import PartitionSpec as P
+ from torch.nn.parameter import Parameter
+ from torchax.interop import jax_view, torch_view
+ from torchax.ops.mappings import t2j
+ from vllm.logger import init_logger
+ from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEConfig
+ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (  # noqa: E501
+     CompressedTensorsMoEMethod, CompressedTensorsW8A8Fp8MoEMethod)
+
+ from tpu_inference.layers.vllm.fused_moe import fused_moe_func
+ from tpu_inference.layers.vllm.linear_common import \
+     reorder_concatenated_tensor_for_sharding
+ from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
+ from tpu_inference.layers.vllm.quantization.unquantized import \
+     VllmUnquantizedFusedMoEMethod
+
+ logger = init_logger(__name__)
+
+
+ class VllmCompressedTensorsMoEMethod(CompressedTensorsMoEMethod):
+
+     @staticmethod
+     def get_moe_method(
+         quant_config: "VllmCompressedTensorsConfig",  # type: ignore # noqa E501
+         layer: torch.nn.Module,
+         layer_name: str,
+     ) -> CompressedTensorsMoEMethod:
+         assert isinstance(layer, FusedMoE)
+
+         # FusedMoE was made by combining multiple Linears, so make sure the
+         # quantization config for Linear can target it.
+         quant_config._add_fused_moe_to_target_scheme_map()
+         unfused_names = [
+             layer_name + proj_name
+             for proj_name in [".0.gate_proj", ".0.up_proj", ".0.down_proj"]
+         ]
+         # TODO: refactor this to use expert_mapping and check all layer numbers
+         all_scheme_dicts = [
+             quant_config.get_scheme_dict(layer, name) for name in unfused_names
+         ]
+         scheme_dict = all_scheme_dicts.pop()
+
+         # multiple schemes found
+         if not all([cur_dict == scheme_dict for cur_dict in all_scheme_dicts]):
+             raise ValueError("All MoE projections need to have the same "
+                              "quantization scheme but found multiple")
+
+         if scheme_dict is None:
+             return VllmUnquantizedFusedMoEMethod(layer.moe_config,
+                                                  quant_config.mesh)
+
+         weight_quant = scheme_dict.get("weights")
+         input_quant = scheme_dict.get("input_activations")
+
+         if quant_config._is_fp8_w8a8(weight_quant, input_quant):
+             return VllmCompressedTensorsW8A8Fp8MoEMethod(
+                 weight_quant, input_quant, layer.moe_config, quant_config.mesh)
+         else:
+             raise RuntimeError(
+                 f"Unsupported FusedMoE scheme: {weight_quant}, {input_quant}")
+
+
+ class VllmCompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsW8A8Fp8MoEMethod,
+                                             JaxCommonConfig):
+
+     def __init__(
+         self,
+         weight_quant: QuantizationArgs,
+         input_quant: QuantizationArgs,
+         moe: FusedMoEConfig,
+         mesh: Mesh,
+     ):
+         super().__init__(weight_quant, input_quant, moe)
+         self.mesh = mesh
+
+     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+         """
+         Convert the loaded torch FP8 MoE weights into sharded JAX arrays.
+
+         :param self: This quantization method.
+         :param layer: FusedMoE layer whose weights are converted in place.
+         :type layer: torch.nn.Module
+
+         Steps:
+         1. Read weights from layer object and convert to jax arrays
+         2. Interleave concat w13 weights
+         3. Shard weights for tp (rowwise w13, colwise w2)
+         4. Initialize Params as torch.nn.Parameter
+            a. w13_weight - float8_e4m3fn shape: (num_experts, 2 x intermediate_size, input_size)
+            b. w2_weight - float8_e4m3fn shape: (num_experts, output_size, intermediate_size)
+            c. w13_weight_scale - FP32 shape: (num_experts, 2 x intermediate_size, 1)
+            d. w2_weight_scale - FP32 shape: (num_experts, output_size, 1)
+         """
+         assert isinstance(layer, FusedMoE)
+
+         # Read weights from layer object
+         w13_weight = t2j(
+             layer.w13_weight, use_dlpack=False
+         )  # float8_e4m3fn shape: (num_experts, 2 x intermediate_size, input_size)
+         w13_weight_scale = t2j(
+             layer.w13_weight_scale, use_dlpack=False
+         )  # FP32 shape: (num_experts, 2 x intermediate_size, 1)
+         w2_weight = t2j(
+             layer.w2_weight, use_dlpack=False
+         )  # float8_e4m3fn shape: (num_experts, output_size, intermediate_size)
+         w2_weight_scale = t2j(layer.w2_weight_scale, use_dlpack=False)
+         w13_weight_scale = w13_weight_scale.astype(jnp.bfloat16)
+         w2_weight_scale = w2_weight_scale.astype(jnp.bfloat16)
+         intermediate_size = layer.w13_weight.shape[1] // 2
+         assert intermediate_size == w2_weight.shape[-1]
+         n_shards = self.mesh.shape["model"]
+         assert intermediate_size % n_shards == 0
+         num_experts, hidden_size, intermediate_size = w2_weight.shape
+         assert w2_weight_scale.shape == (num_experts, hidden_size, 1)
+         assert w13_weight.shape == (num_experts, 2 * intermediate_size,
+                                     hidden_size)
+         assert w13_weight_scale.shape == (num_experts, 2 * intermediate_size,
+                                           1)
+
+         if not layer.use_ep:
+             # Interleave concat w13 weights
+             w13_weight = reorder_concatenated_tensor_for_sharding(
+                 w13_weight,
+                 split_sizes=(intermediate_size, intermediate_size),
+                 dim=1,
+                 n_shards=n_shards,
+             )
+             # Interleave concat w13 weight scales
+             w13_weight_scale = reorder_concatenated_tensor_for_sharding(
+                 w13_weight_scale,
+                 split_sizes=(intermediate_size, intermediate_size),
+                 dim=1,
+                 n_shards=n_shards,
+             )
+
+         # 160,5120,1 -> 160,1,5120
+         w13_weight_scale = jnp.swapaxes(w13_weight_scale, 1, 2)
+         # 160,1,5120 -> 160, 1, 1, 5120 (num_experts, num_blocks, 1, outer_dim)
+         w13_weight_scale = jnp.expand_dims(w13_weight_scale, 2)
+         w2_weight_scale = jnp.swapaxes(w2_weight_scale, 1, 2)
+         w2_weight_scale = jnp.expand_dims(w2_weight_scale, 2)
+
+         if layer.use_ep:
+             # Apply EP sharding
+             ep_sharding = NamedSharding(self.mesh, P("model"))
+
+             w13_weight = jax.lax.with_sharding_constraint(
+                 w13_weight, ep_sharding)
+             w2_weight = jax.lax.with_sharding_constraint(
+                 w2_weight, ep_sharding)
+
+             w13_weight_scale = jax.lax.with_sharding_constraint(
+                 w13_weight_scale, ep_sharding)
+             w2_weight_scale = jax.lax.with_sharding_constraint(
+                 w2_weight_scale, ep_sharding)
+
+         else:
+             # Shard weights for tp (rowwise w13, colwise w2)
+             w13_format = Format(
+                 Layout((0, 1, 2)),  # expert, 2xintermed, input
+                 NamedSharding(self.mesh, P(None, "model", None)),
+             )  # rowwise sharding on intermed dim
+
+             w13_scale_format = Format(
+                 Layout(
+                     (0, 1, 2, 3)),  # (num_experts, num_blocks, 1, outer_dim)
+                 NamedSharding(self.mesh, P(None, None, None, "model")),
+             )  # col wise GMM sharding on intermed dim
+
+             # Local shard shape: (num_experts, 2 x (intermediate_size // n_shards), input_size)
+             w13_weight = jax.lax.with_sharding_constraint(
+                 w13_weight, w13_format)
+             # Local shard shape: (num_experts, 1, 1, 2 x (intermediate_size // n_shards))
+             w13_weight_scale = jax.lax.with_sharding_constraint(
+                 w13_weight_scale, w13_scale_format)
+
+             # Shard weights for tp (colwise w2)
+             w2_format = Format(
+                 Layout((0, 1, 2)),  # expert, hidden, intermed
+                 NamedSharding(self.mesh, P(None, None, "model")),
+             )
+             # Local shard shape: (num_experts, hidden, (intermediate_size // n_shards))
+             # (colwise sharding on the intermediate dim)
+             w2_weight = jax.lax.with_sharding_constraint(w2_weight, w2_format)
+
+             w2_scale_format = Format(
+                 Layout((0, 1, 2, 3)),  # (num_experts, 1, 1, hidden_size)
+                 NamedSharding(self.mesh, P(None, None, None, None)),
+             )
+             # Replicated shape: (num_experts, 1, 1, hidden_size)
+             w2_weight_scale = jax.lax.with_sharding_constraint(
+                 w2_weight_scale, w2_scale_format)
+
+         w13_weight = Parameter(torch_view(w13_weight), requires_grad=False)
+         w13_weight_scale = Parameter(torch_view(w13_weight_scale),
+                                      requires_grad=False)
+         w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
+         w2_weight_scale = Parameter(torch_view(w2_weight_scale),
+                                     requires_grad=False)
+
+         layer.w13_weight = w13_weight
+         layer.w13_weight_scale = w13_weight_scale
+         layer.w2_weight = w2_weight
+         layer.w2_weight_scale = w2_weight_scale
+
+     def apply(
+         self,
+         layer: torch.nn.Module,
+         x: torch.Tensor,
+         router_logits: torch.Tensor,
+     ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+         assert isinstance(layer, FusedMoE)
+         if layer.activation != "silu":
+             raise NotImplementedError(
+                 "Only silu is supported for activation function.")
+         if layer.scoring_func != "softmax":
+             raise NotImplementedError(
+                 "Only softmax is supported for scoring_func")
+
+         # TODO: Use MoE kernel when it supports fp8
+         x = jax_view(x)
+         w13_weight = jax_view(layer.w13_weight)
+         w2_weight = jax_view(layer.w2_weight)
+         w13_weight_scale = jax_view(layer.w13_weight_scale)
+         w2_weight_scale = jax_view(layer.w2_weight_scale)
+         gating_output = jax_view(router_logits)
+         out = torch_view(
+             fused_moe_func(
+                 hidden_states=x,
+                 w1=w13_weight,
+                 w2=w2_weight,
+                 w1_scale=w13_weight_scale,
+                 w2_scale=w2_weight_scale,
+                 w1_bias=None,
+                 w2_bias=None,
+                 gating_output=gating_output,
+                 topk=layer.top_k,
+                 renormalize=layer.renormalize,
+                 mesh=self.mesh,
+                 use_ep=layer.use_ep,
+                 activation=layer.activation,
+             ))
+
+         return out
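Step 2 of process_weights_after_loading above ("interleave concat w13 weights") reorders the concatenated gate/up projection so that an even tensor-parallel split along the intermediate dimension hands every shard a contiguous (gate chunk, up chunk) pair rather than gate-only and up-only halves. Below is a minimal jax.numpy sketch of that reordering; the helper is hypothetical and illustrative only, and the packaged reorder_concatenated_tensor_for_sharding in linear_common.py may be implemented differently.

import jax.numpy as jnp

def reorder_w13_for_sharding(w13: jnp.ndarray, n_shards: int) -> jnp.ndarray:
    """w13: (num_experts, 2 * intermediate_size, hidden), rows [0:I] are
    gate_proj and rows [I:2I] are up_proj. After reordering, an even split
    along axis 1 into n_shards gives every shard its own contiguous
    (gate chunk, up chunk) pair."""
    gate, up = jnp.split(w13, 2, axis=1)
    gate_chunks = jnp.split(gate, n_shards, axis=1)
    up_chunks = jnp.split(up, n_shards, axis=1)
    shard_blocks = [
        jnp.concatenate([g, u], axis=1)
        for g, u in zip(gate_chunks, up_chunks)
    ]
    return jnp.concatenate(shard_blocks, axis=1)

# Example: 2 experts, intermediate_size=4, hidden=3, 2 shards.
w13 = jnp.arange(2 * 8 * 3, dtype=jnp.float32).reshape(2, 8, 3)
reordered = reorder_w13_for_sharding(w13, n_shards=2)
# Shard 0 now holds gate rows 0-1 followed by up rows 0-1, shard 1 the rest.
print(reordered.shape)  # (2, 8, 3)

The same reordering is applied to w13_weight_scale in the non-EP branch so the per-row scales stay aligned with their rows after sharding.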
tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@@ -0,0 +1,222 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional
+
+ import jax
+ import jax.numpy as jnp
+ import torch
+ from compressed_tensors.quantization import (QuantizationArgs,
+                                              QuantizationStrategy)
+ from jax.sharding import NamedSharding, PartitionSpec
+ from torchax.interop import jax_view, torch_view
+ from torchax.ops.mappings import t2j
+ from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
+     CompressedTensorsW8A8Fp8
+ from vllm.model_executor.layers.quantization.utils.w8a8_utils import \
+     per_tensor_dequantize
+
+ from tpu_inference.layers.vllm.linear_common import (
+     sharded_quantized_matmul, slice_sharded_tensor_for_concatenation,
+     torch_to_jax_param)
+ from tpu_inference.layers.vllm.quantization.common import JaxCommonLinearConfig
+
+ P = PartitionSpec
+
+
+ def requantize_with_max_scale(
+         weight: torch.Tensor, weight_scale: torch.Tensor,
+         logical_widths: list[int]) -> tuple[torch.Tensor, torch.Tensor]:
+     dtype = weight.dtype
+     dtype_info = torch.finfo(dtype)
+     maxval = float(dtype_info.max)
+     minval = float(dtype_info.min)
+
+     max_w_scale = weight_scale.max()
+
+     unfused_module_in_checkpoint = (weight_scale[-1]
+                                     > torch.finfo(torch.float8_e4m3fn).min)
+
+     # If unfused checkpoint, need to requantize with the single scale.
+     if unfused_module_in_checkpoint:
+         start = 0
+         for idx, logical_width in enumerate(logical_widths):
+             # Skip any component with zero width.
+             if logical_width == 0:
+                 continue
+             end = start + logical_width
+             weight_dq = per_tensor_dequantize(weight[start:end, :],
+                                               weight_scale[idx])
+             weight_q = weight_dq / max_w_scale
+             weight[start:end, :] = weight_q.clamp(minval, maxval).to(dtype)
+             start = end
+
+     return max_w_scale, weight
+
+
+ class VllmCompressedTensorsW8A8Fp8(CompressedTensorsW8A8Fp8):
+
+     def __init__(
+         self,
+         weight_quant: QuantizationArgs,
+         is_static_input_scheme: bool,
+         jax_config: JaxCommonLinearConfig,
+     ):
+         super().__init__(weight_quant, is_static_input_scheme)
+
+         self.jax_config = jax_config
+
+     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+         weight = layer.weight
+         weight_scale = layer.weight_scale
+
+         if self.is_static_input_scheme:
+             # In static quant, all input_scales share the same value.
+             assert layer.input_scale.min() == layer.input_scale.max()
+             input_scale_first = layer.input_scale[0]
+
+             input_scale = jax.device_put(
+                 t2j(input_scale_first, use_dlpack=False),
+                 NamedSharding(self.jax_config.mesh, P()))
+             input_scale = torch.nn.Parameter(torch_view(input_scale),
+                                              requires_grad=False)
+             delattr(layer, "input_scale")
+             layer.input_scale = input_scale
+
+             # TODO(kyuyeunk): Investigate performance gain from merging scales.
+             # By merging input and weight scales, we reduce the number of muls
+             # required for dequantization from 2 (one for each scale) to 1.
+             # weight_scale *= input_scale_first
+
+         if self.strategy == QuantizationStrategy.TENSOR:
+             weight_scale, weight = requantize_with_max_scale(
+                 weight, weight_scale, self.jax_config.output_sizes)
+             weight_scale = jax.device_put(
+                 t2j(weight_scale, use_dlpack=False),
+                 NamedSharding(self.jax_config.mesh, P()))
+             weight_scale = torch.nn.Parameter(torch_view(weight_scale),
+                                               requires_grad=False)
+         else:
+             weight_scale = weight_scale.squeeze(-1)
+             weight_scale = torch_to_jax_param(
+                 weight_scale,
+                 NamedSharding(self.jax_config.mesh,
+                               self.jax_config.bias_sharding),
+                 self.jax_config.output_sizes, self.jax_config.n_shards,
+                 self.jax_config.fuse_matmuls)
+         delattr(layer, "weight_scale")
+         layer.weight_scale = weight_scale
+
+         weight = torch_to_jax_param(
+             layer.weight,
+             NamedSharding(self.jax_config.mesh,
+                           self.jax_config.weight_sharding),
+             self.jax_config.output_sizes, self.jax_config.n_shards,
+             self.jax_config.fuse_matmuls)
+         delattr(layer, "weight")
+         layer.weight = weight
+
+         if layer.bias is not None:
+             bias = torch_to_jax_param(
+                 layer.bias,
+                 NamedSharding(self.jax_config.mesh,
+                               self.jax_config.bias_sharding),
+                 self.jax_config.output_sizes, self.jax_config.n_shards,
+                 self.jax_config.fuse_matmuls)
+             delattr(layer, "bias")
+             layer.bias = bias
+
+     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+                       bias: Optional[torch.Tensor]) -> torch.Tensor:
+         with jax.named_scope(layer._get_name()):
+             if self.jax_config.fuse_matmuls:
+                 return self._apply_fused(layer, x, bias)
+             else:
+                 return self._apply_split(layer, x, bias)
+
+     def _apply_fused(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]) -> torch.Tensor:
+         x_jax = jax_view(x)
+         weight_jax = jax_view(layer.weight)
+         weight_scale_jax = jax_view(layer.weight_scale)
+
+         if self.is_static_input_scheme:
+             # TODO(kyuyeunk): Add kernel support for static quant
+             input_scale = jax_view(layer.input_scale)
+             dtype_info = jnp.finfo(weight_jax.dtype)
+             maxval = float(dtype_info.max)
+             minval = float(dtype_info.min)
+             x_q = jnp.clip(x_jax / input_scale.astype(x_jax.dtype), minval,
+                            maxval).astype(weight_jax.dtype)
+
+             outs = jax.lax.dot_general(
+                 x_q,
+                 weight_jax,
+                 (((1, ), (1, )), ((), ())),
+                 preferred_element_type=jnp.float32,
+             )
+             # Dequantize with both scales, matching _apply_split below.
+             outs *= weight_scale_jax * input_scale
+             outs = outs.astype(x_jax.dtype)
+         else:
+             outs = sharded_quantized_matmul(x_jax, weight_jax,
+                                             weight_scale_jax,
+                                             self.jax_config.mesh,
+                                             self.jax_config.weight_sharding)
+
+         if bias is not None and not layer.skip_bias_add:
+             outs += jax_view(bias)
+         outs = slice_sharded_tensor_for_concatenation(
+             outs, self.jax_config.output_sizes, self.jax_config.n_shards)
+         return torch_view(jnp.concatenate(outs, axis=-1))
+
+     def _apply_split(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]) -> torch.Tensor:
+         assert isinstance(layer.weight, torch.nn.ParameterList)
+
+         x_jax = jax_view(x)
+         outs = []
+         for i, (weight, weight_scale) in enumerate(
+                 zip(layer.weight, layer.weight_scale)):
+             weight_jax = jax_view(weight)
+             weight_scale_jax = jax_view(weight_scale)
+
+             if self.is_static_input_scheme:
+                 # TODO(kyuyeunk): Add kernel support for static quant
+                 input_scale = jax_view(layer.input_scale)
+                 dtype_info = jnp.finfo(weight_jax.dtype)
+                 maxval = float(dtype_info.max)
+                 minval = float(dtype_info.min)
+                 x_q = jnp.clip(x_jax / input_scale.astype(x_jax.dtype), minval,
+                                maxval).astype(weight_jax.dtype)
+
+                 out = jax.lax.dot_general(
+                     x_q,
+                     weight_jax,
+                     (((1, ), (1, )), ((), ())),
+                     preferred_element_type=jnp.float32,
+                 )
+                 # TODO(kyuyeunk): Investigate performance gain from merging scales.
+                 # out *= weight_scale_jax
+                 out *= weight_scale_jax * input_scale
+                 out = out.astype(x_jax.dtype)
+             else:
+                 out = sharded_quantized_matmul(x_jax, weight_jax,
+                                                weight_scale_jax,
+                                                self.jax_config.mesh,
+                                                self.jax_config.weight_sharding)
+
+             if bias is not None and not layer.skip_bias_add:
+                 out += jax_view(bias[i])
+             outs.append(out)
+         return torch_view(jnp.concatenate(outs, axis=-1))
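The static-input branches in _apply_fused and _apply_split above divide activations by a per-tensor input_scale, run the matmul on quantized operands with an fp32 accumulator, and then rescale the result. The sketch below shows that per-tensor W8A8 arithmetic numerically; it uses float32 as a stand-in for float8_e4m3fn so it runs anywhere (448 is the e4m3 max), and is illustrative only, not the packaged implementation.

import jax.numpy as jnp

x = jnp.array([[0.5, -1.0, 2.0]], dtype=jnp.float32)             # (tokens, in_features)
w = jnp.array([[0.1, 0.2, -0.3], [0.4, 0.0, 0.6]], jnp.float32)  # (out_features, in_features)

input_scale = jnp.float32(0.05)             # static per-tensor activation scale
weight_scale = jnp.max(jnp.abs(w)) / 448.0  # 448 ~= finfo(float8_e4m3fn).max

x_q = x / input_scale    # the packaged code also clips and casts to the fp8 dtype
w_q = w / weight_scale   # checkpoints store weights already quantized like this

acc = jnp.einsum("ti,oi->to", x_q, w_q)  # matmul over quantized operands, fp32 accumulator
y = acc * input_scale * weight_scale     # dequantize: both scales apply to the accumulator

print(jnp.allclose(y, x @ w.T, atol=1e-4))  # True

Because x_q carries a 1/input_scale factor and w_q carries a 1/weight_scale factor, both scales must multiply the accumulator to recover the unquantized product, which is why the static paths above apply weight_scale and input_scale together.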