tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (248)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +14 -0
  31. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  32. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
  35. tests/layers/__init__.py +13 -0
  36. tests/layers/common/__init__.py +13 -0
  37. tests/layers/common/test_attention_interface.py +156 -0
  38. tests/layers/common/test_quantization.py +149 -0
  39. tests/layers/jax/__init__.py +13 -0
  40. tests/layers/jax/attention/__init__.py +13 -0
  41. tests/layers/jax/attention/test_common_attention.py +103 -0
  42. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  43. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  44. tests/layers/jax/moe/__init__.py +13 -0
  45. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  46. tests/layers/jax/sample/__init__.py +13 -0
  47. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  48. tests/layers/jax/sample/test_sampling.py +115 -0
  49. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  50. tests/layers/jax/test_layers.py +155 -0
  51. tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
  52. tests/layers/jax/test_rope.py +93 -0
  53. tests/layers/jax/test_sharding.py +159 -0
  54. tests/layers/jax/test_transformer_block.py +152 -0
  55. tests/layers/vllm/__init__.py +13 -0
  56. tests/layers/vllm/test_attention.py +363 -0
  57. tests/layers/vllm/test_awq.py +406 -0
  58. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  59. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  61. tests/layers/vllm/test_fp8.py +17 -0
  62. tests/layers/vllm/test_mxfp4.py +320 -0
  63. tests/layers/vllm/test_unquantized.py +662 -0
  64. tests/layers/vllm/utils.py +87 -0
  65. tests/lora/__init__.py +13 -0
  66. tests/lora/conftest.py +14 -0
  67. tests/lora/test_bgmv.py +14 -0
  68. tests/lora/test_layers.py +25 -8
  69. tests/lora/test_lora.py +15 -1
  70. tests/lora/test_lora_perf.py +14 -0
  71. tests/models/__init__.py +13 -0
  72. tests/models/common/__init__.py +13 -0
  73. tests/models/common/test_model_loader.py +455 -0
  74. tests/models/jax/__init__.py +13 -0
  75. tests/models/jax/test_deepseek_v3.py +401 -0
  76. tests/models/jax/test_llama3.py +184 -0
  77. tests/models/jax/test_llama4.py +298 -0
  78. tests/models/jax/test_llama_eagle3.py +197 -0
  79. tests/models/jax/test_llama_guard_4.py +242 -0
  80. tests/models/jax/test_qwen2.py +172 -0
  81. tests/models/jax/test_qwen2_5_vl.py +605 -0
  82. tests/models/jax/test_qwen3.py +169 -0
  83. tests/models/jax/test_weight_loading.py +180 -0
  84. tests/models/jax/utils/__init__.py +13 -0
  85. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  86. tests/platforms/__init__.py +13 -0
  87. tests/platforms/test_tpu_platform.py +54 -0
  88. tests/runner/__init__.py +13 -0
  89. tests/runner/test_block_table.py +395 -0
  90. tests/runner/test_input_batch.py +226 -0
  91. tests/runner/test_kv_cache.py +220 -0
  92. tests/runner/test_kv_cache_manager.py +498 -0
  93. tests/runner/test_multimodal_manager.py +429 -0
  94. tests/runner/test_persistent_batch_manager.py +84 -0
  95. tests/runner/test_speculative_decoding_manager.py +368 -0
  96. tests/runner/test_structured_decoding_manager.py +220 -0
  97. tests/runner/test_tpu_runner.py +261 -0
  98. tests/runner/test_tpu_runner_dp.py +1099 -0
  99. tests/runner/test_tpu_runner_mesh.py +200 -0
  100. tests/runner/test_utils.py +411 -0
  101. tests/spec_decode/__init__.py +13 -0
  102. tests/spec_decode/test_eagle3.py +311 -0
  103. tests/test_base.py +14 -0
  104. tests/test_tpu_info.py +14 -0
  105. tests/test_utils.py +1 -43
  106. tests/worker/__init__.py +13 -0
  107. tests/worker/tpu_worker_test.py +414 -0
  108. tpu_inference/__init__.py +14 -0
  109. tpu_inference/core/__init__.py +13 -0
  110. tpu_inference/core/sched/__init__.py +13 -0
  111. tpu_inference/core/sched/dp_scheduler.py +372 -56
  112. tpu_inference/distributed/__init__.py +13 -0
  113. tpu_inference/distributed/jax_parallel_state.py +14 -0
  114. tpu_inference/distributed/tpu_connector.py +14 -9
  115. tpu_inference/distributed/utils.py +56 -4
  116. tpu_inference/executors/__init__.py +13 -0
  117. tpu_inference/executors/ray_distributed_executor.py +20 -3
  118. tpu_inference/experimental/__init__.py +13 -0
  119. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  120. tpu_inference/kernels/__init__.py +13 -0
  121. tpu_inference/kernels/collectives/__init__.py +13 -0
  122. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  123. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  124. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  125. tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
  126. tpu_inference/kernels/megablox/__init__.py +13 -0
  127. tpu_inference/kernels/megablox/common.py +54 -0
  128. tpu_inference/kernels/megablox/gmm.py +646 -0
  129. tpu_inference/kernels/mla/__init__.py +13 -0
  130. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  131. tpu_inference/kernels/mla/v1/kernel.py +20 -26
  132. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  133. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  134. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  135. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  136. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
  137. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
  138. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  139. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
  140. tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
  141. tpu_inference/layers/__init__.py +13 -0
  142. tpu_inference/layers/common/__init__.py +13 -0
  143. tpu_inference/layers/common/attention_interface.py +26 -19
  144. tpu_inference/layers/common/attention_metadata.py +14 -0
  145. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  146. tpu_inference/layers/common/quant_methods.py +15 -0
  147. tpu_inference/layers/common/quantization.py +282 -0
  148. tpu_inference/layers/common/sharding.py +22 -3
  149. tpu_inference/layers/common/utils.py +94 -0
  150. tpu_inference/layers/jax/__init__.py +13 -0
  151. tpu_inference/layers/jax/attention/__init__.py +13 -0
  152. tpu_inference/layers/jax/attention/attention.py +19 -6
  153. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
  154. tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
  155. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  156. tpu_inference/layers/jax/base.py +14 -0
  157. tpu_inference/layers/jax/constants.py +13 -0
  158. tpu_inference/layers/jax/layers.py +14 -0
  159. tpu_inference/layers/jax/misc.py +14 -0
  160. tpu_inference/layers/jax/moe/__init__.py +13 -0
  161. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  162. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  163. tpu_inference/layers/jax/moe/moe.py +43 -3
  164. tpu_inference/layers/jax/pp_utils.py +53 -0
  165. tpu_inference/layers/jax/rope.py +14 -0
  166. tpu_inference/layers/jax/rope_interface.py +14 -0
  167. tpu_inference/layers/jax/sample/__init__.py +13 -0
  168. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  169. tpu_inference/layers/jax/sample/sampling.py +15 -1
  170. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  171. tpu_inference/layers/jax/transformer_block.py +14 -0
  172. tpu_inference/layers/vllm/__init__.py +13 -0
  173. tpu_inference/layers/vllm/attention.py +4 -4
  174. tpu_inference/layers/vllm/fused_moe.py +100 -455
  175. tpu_inference/layers/vllm/linear.py +64 -0
  176. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  177. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  178. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  179. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  180. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  181. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  182. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  183. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
  184. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  188. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
  189. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  190. tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
  191. tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
  192. tpu_inference/lora/__init__.py +13 -0
  193. tpu_inference/lora/torch_lora_ops.py +8 -13
  194. tpu_inference/models/__init__.py +13 -0
  195. tpu_inference/models/common/__init__.py +13 -0
  196. tpu_inference/models/common/model_loader.py +37 -16
  197. tpu_inference/models/jax/__init__.py +13 -0
  198. tpu_inference/models/jax/deepseek_v3.py +113 -124
  199. tpu_inference/models/jax/gpt_oss.py +23 -7
  200. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  201. tpu_inference/models/jax/llama3.py +99 -36
  202. tpu_inference/models/jax/llama4.py +14 -0
  203. tpu_inference/models/jax/llama_eagle3.py +14 -0
  204. tpu_inference/models/jax/llama_guard_4.py +15 -1
  205. tpu_inference/models/jax/qwen2.py +17 -2
  206. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  207. tpu_inference/models/jax/qwen3.py +17 -2
  208. tpu_inference/models/jax/utils/__init__.py +13 -0
  209. tpu_inference/models/jax/utils/file_utils.py +14 -0
  210. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  211. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
  213. tpu_inference/models/jax/utils/weight_utils.py +32 -1
  214. tpu_inference/models/vllm/__init__.py +13 -0
  215. tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
  216. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  217. tpu_inference/platforms/__init__.py +14 -0
  218. tpu_inference/platforms/tpu_platform.py +27 -29
  219. tpu_inference/runner/__init__.py +13 -0
  220. tpu_inference/runner/compilation_manager.py +69 -35
  221. tpu_inference/runner/kv_cache.py +14 -0
  222. tpu_inference/runner/kv_cache_manager.py +15 -2
  223. tpu_inference/runner/lora_utils.py +16 -1
  224. tpu_inference/runner/multimodal_manager.py +16 -2
  225. tpu_inference/runner/persistent_batch_manager.py +14 -0
  226. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  227. tpu_inference/runner/structured_decoding_manager.py +14 -0
  228. tpu_inference/runner/tpu_runner.py +30 -10
  229. tpu_inference/spec_decode/__init__.py +13 -0
  230. tpu_inference/spec_decode/jax/__init__.py +13 -0
  231. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  232. tpu_inference/tpu_info.py +14 -0
  233. tpu_inference/utils.py +31 -30
  234. tpu_inference/worker/__init__.py +13 -0
  235. tpu_inference/worker/tpu_worker.py +23 -7
  236. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
  237. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  238. tpu_inference/layers/vllm/linear_common.py +0 -208
  239. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  240. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  241. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  242. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  245. tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
  246. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  247. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  248. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
@@ -1,24 +1,39 @@
- from typing import Union
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.

  import jax
- import jax.numpy as jnp
  import torch
- import torch.nn.functional as F
  from compressed_tensors.quantization import QuantizationArgs
- from jax.experimental.layout import Format, Layout
- from jax.sharding import Mesh, NamedSharding
- from jax.sharding import PartitionSpec as P
+ from jax.sharding import Mesh
  from torch.nn.parameter import Parameter
- from torchax.interop import call_jax, torch_view
+ from torchax.interop import torch_view
  from torchax.ops.mappings import t2j
- from vllm.logger import init_logger
  from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEConfig
  from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501
      CompressedTensorsMoEMethod, CompressedTensorsW8A8Fp8MoEMethod)

- from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
+ from tpu_inference.layers.common.sharding import ShardingAxisName
+ from tpu_inference.layers.vllm.fused_moe import (FusedMoEBackend,
+                                                  fused_moe_apply,
+                                                  select_moe_backend)
+ from tpu_inference.layers.vllm.process_weights.fused_moe_weights import (
+     FusedMoEWeights, process_moe_weights, shard_moe_weights)
+ from tpu_inference.layers.vllm.quantization.configs import VllmQuantConfig
  from tpu_inference.layers.vllm.quantization.unquantized import \
      VllmUnquantizedFusedMoEMethod
+ from tpu_inference.logger import init_logger
+ from tpu_inference.utils import get_mesh_shape_product

  logger = init_logger(__name__)

@@ -31,7 +46,6 @@ class VllmCompressedTensorsMoEMethod(CompressedTensorsMoEMethod):
          layer: torch.nn.Module,
          layer_name: str,
      ) -> CompressedTensorsMoEMethod:
-
          assert isinstance(layer, FusedMoE)

          # FusedMoE was made by combining multiple Linears so need to
@@ -68,15 +82,44 @@ class VllmCompressedTensorsMoEMethod(CompressedTensorsMoEMethod):


  class VllmCompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsW8A8Fp8MoEMethod,
-                                             JaxCommonConfig):
+                                             VllmQuantConfig):

-     def __init__(self, weight_quant: QuantizationArgs,
-                  input_quant: QuantizationArgs, moe: FusedMoEConfig,
-                  mesh: Mesh):
+     def __init__(
+         self,
+         weight_quant: QuantizationArgs,
+         input_quant: QuantizationArgs,
+         moe: FusedMoEConfig,
+         mesh: Mesh,
+     ):
          super().__init__(weight_quant, input_quant, moe)
+
          self.mesh = mesh
+         self.moe_backend = select_moe_backend(self.moe)
+
+         self.extra_backend_kwargs = {}
+         if self.moe_backend == FusedMoEBackend.FUSED_MOE:
+             raise NotImplementedError(
+                 "Per-channel quantization is not supported in FusedMoE kernel."
+             )

      def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+         """
+         Docstring for process_weights_after_loading
+
+         :param self: Description
+         :param layer: Description
+         :type layer: torch.nn.Module
+
+         Steps:
+         1. Read weights from layer object and convert to jax arrays
+         2. Interleave concat w13 weights
+         3. Shard weights for tp (rowwise w13, colwise w2)
+         4. Initialize Params as torch.nn.Parameter
+             a. w13_weight - float8_e4m3fn shape: (num_experts, 2 x intermediate_size, input_size)
+             b. w2_weight - float8_e4m3fn shape: (num_experts, output_size, intermediate_size)
+             c. w13_weight_scale - FP32 shape: (num_experts, 2 x intermediate_size, 1)
+             d. w2_weight_scale - FP32shape: (num_experts, output_size, 1)
+         """
          assert isinstance(layer, FusedMoE)

          w13_weight = t2j(layer.w13_weight, use_dlpack=False)
@@ -84,129 +127,73 @@ class VllmCompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsW8A8Fp8MoEMethod,
          w2_weight = t2j(layer.w2_weight, use_dlpack=False)
          w2_weight_scale = t2j(layer.w2_weight_scale, use_dlpack=False)

-         w13_weight_scale = w13_weight_scale.astype(jnp.bfloat16)
-         w2_weight_scale = w2_weight_scale.astype(jnp.bfloat16)
-
-         num_experts, hidden_size, intermediate_size = w2_weight.shape
-         assert w2_weight_scale.shape == (num_experts, hidden_size, 1)
-         assert w13_weight.shape == (num_experts, 2 * intermediate_size,
-                                     hidden_size)
-         assert w13_weight_scale.shape == (num_experts, 2 * intermediate_size,
-                                           1)
-
-         w1_weight, w3_weight = jnp.split(w13_weight, 2, 1)
-         w1_weight_scale, w3_weight_scale = jnp.split(w13_weight_scale, 2, 1)
-
-         if layer.use_ep:
-             format = Format(Layout((0, 1, 2)),
-                             NamedSharding(self.mesh, P("model", None, None)))
-             w1_weight = jax.device_put(w1_weight, format)
-             w1_weight_scale = jax.device_put(w1_weight_scale, format)
-             w3_weight = jax.device_put(w3_weight, format)
-             w3_weight_scale = jax.device_put(w3_weight_scale, format)
-             w2_weight = jax.device_put(w2_weight, format)
-             w2_weight_scale = jax.device_put(w2_weight_scale, format)
+         if self.moe.has_bias:
+             w13_bias = t2j(layer.w13_bias, use_dlpack=False)
+             w2_bias = t2j(layer.w2_bias, use_dlpack=False)
          else:
-             n_shards = self.mesh.shape["model"]
-             assert intermediate_size % n_shards == 0
-
-             w13_format = Format(
-                 Layout((0, 1, 2)),
-                 NamedSharding(self.mesh, P(None, "model", None)))
-             w1_weight = jax.device_put(w1_weight, w13_format)
-             w1_weight_scale = jax.device_put(w1_weight_scale, w13_format)
-             w3_weight = jax.device_put(w3_weight, w13_format)
-             w3_weight_scale = jax.device_put(w3_weight_scale, w13_format)
-             w2_weight = jax.device_put(
-                 w2_weight,
-                 Format(Layout((0, 1, 2)),
-                        NamedSharding(self.mesh, P(None, None, "model"))),
+             w13_bias = w2_bias = None
+
+         @jax.jit
+         def process_fp8_moe_weights(
+             w13_weight: jax.Array,
+             w13_weight_scale: jax.Array,
+             w13_bias: jax.Array | None,
+             w2_weight: jax.Array,
+             w2_weight_scale: jax.Array,
+             w2_bias: jax.Array | None,
+         ) -> FusedMoEWeights:
+             w13_interleave = layer.activation == "swigluoai"
+             w13_reorder_size = get_mesh_shape_product(
+                 self.mesh, ShardingAxisName.MLP_TENSOR)
+
+             return process_moe_weights(
+                 weights=FusedMoEWeights(
+                     w13_weight=w13_weight,
+                     w13_weight_scale=w13_weight_scale,
+                     w13_bias=w13_bias,
+                     w2_weight=w2_weight,
+                     w2_weight_scale=w2_weight_scale,
+                     w2_bias=w2_bias,
+                 ),
+                 moe_backend=self.moe_backend,
+                 w13_reorder_size=w13_reorder_size,
+                 w13_interleave=w13_interleave,
              )
-         w2_weight_scale = jax.device_put(
-             w2_weight_scale,
-             Format(Layout((0, 1, 2)), NamedSharding(self.mesh, P())),
-         ) # replicate
-
-         w1_weight = Parameter(torch_view(w1_weight), requires_grad=False)
-         w1_weight_scale = Parameter(torch_view(w1_weight_scale),
-                                     requires_grad=False)
-         w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
-         w2_weight_scale = Parameter(torch_view(w2_weight_scale),
-                                     requires_grad=False)
-         w3_weight = Parameter(torch_view(w3_weight), requires_grad=False)
-         w3_weight_scale = Parameter(torch_view(w3_weight_scale),
-                                     requires_grad=False)
-
-         # TODO dont reuse variable
-         layer.w13_weight = w1_weight
-         layer.w13_weight_scale = w1_weight_scale
-         layer.w2_weight = w2_weight
-         layer.w2_weight_scale = w2_weight_scale
-         layer.w3_weight = w3_weight
-         layer.w3_weight_scale = w3_weight_scale
+
+         weights = process_fp8_moe_weights(
+             w13_weight,
+             w13_weight_scale,
+             w13_bias,
+             w2_weight,
+             w2_weight_scale,
+             w2_bias,
+         )
+         weights = torch_view(
+             shard_moe_weights(weights, self.moe_backend, self.mesh))
+
+         layer.w13_weight = Parameter(weights.w13_weight, requires_grad=False)
+         layer.w2_weight = Parameter(weights.w2_weight, requires_grad=False)
+
+         layer.w13_weight_scale = Parameter(weights.w13_weight_scale,
+                                            requires_grad=False)
+         layer.w2_weight_scale = Parameter(weights.w2_weight_scale,
+                                           requires_grad=False)
+
+         if self.moe.has_bias:
+             layer.w13_bias = Parameter(weights.w13_bias, requires_grad=False)
+             layer.w2_bias = Parameter(weights.w2_bias, requires_grad=False)

      def apply(
          self,
          layer: torch.nn.Module,
          x: torch.Tensor,
          router_logits: torch.Tensor,
-     ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
-         assert isinstance(layer, FusedMoE)
-         if layer.activation != "silu":
-             raise NotImplementedError(
-                 "Only silu is supported for activation function.")
-         if layer.scoring_func != "softmax":
-             raise NotImplementedError(
-                 "Only softmax is supported for scoring_func")
-
-         # TODO: Use MoE kernel when it supports fp8
-         seqlen = x.shape[0]
-
-         expert_weights = F.softmax(router_logits, dim=-1)
-         expert_weights, expert_indices = torch.topk(expert_weights,
-                                                     layer.top_k,
-                                                     dim=-1)
-         if layer.renormalize:
-             expert_weights /= expert_weights.sum(dim=-1, keepdim=True)
-
-         # cond ffn
-         # e = total num of exp = 160
-         # t = seqlen
-         # o = config.imtermediate size
-         # i = config.dim
-         #torch.einsum("ti, eoi -> teo", x, layer.w13_weight) * self.w13_weight_scale)
-         ux1 = call_jax(jax.lax.dot,
-                        x,
-                        layer.w13_weight,
-                        dimension_numbers=(((1, ), (2, )), ((), ())),
-                        preferred_element_type=jnp.bfloat16.dtype)
-         x1 = F.silu(ux1 * layer.w13_weight_scale.squeeze(2))
-
-         #x3 = torch.einsum("ti, eoi -> teo", x, layer.w3_weight) * self.w3_weight_scale
-         x3 = call_jax(jax.lax.dot,
-                       x,
-                       layer.w3_weight,
-                       dimension_numbers=(((1, ), (2, )), ((), ())),
-                       preferred_element_type=jnp.bfloat16.dtype
-                       ) * layer.w3_weight_scale.squeeze(2)
-
-         #expert_outs = torch.einsum("teo, eio -> tei", (x1 * x3), self.w2_weight) * self.w2_weight_scale
-         expert_outs = call_jax(
-             jax.lax.dot,
-             x1 * x3,
-             layer.w2_weight,
-             dimension_numbers=(((2, ), (2, )), ((1, ), (0, ))),
-             preferred_element_type=jnp.bfloat16.dtype).transpose(
-                 0, 1) * layer.w2_weight_scale.squeeze(2)
-
-         seq_indexes = torch.arange(seqlen, device='jax').unsqueeze(1)
-         expert_outs = expert_outs[seq_indexes, expert_indices]
-
-         # out = torch.einsum("tai,ta -> ti", expert_outs, expert_weights)
-         out = call_jax(jax.lax.dot,
-                        expert_outs,
-                        expert_weights,
-                        dimension_numbers=(((1, ), (1, )), ((0, ), (0, ))),
-                        preferred_element_type=jnp.bfloat16.dtype)
-
-         return out
+     ) -> torch.Tensor:
+         return fused_moe_apply(
+             layer,
+             x,
+             router_logits,
+             self.moe_backend,
+             self.mesh,
+             self.extra_backend_kwargs,
+         )
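
The removed apply() above implemented softmax top-k routing and the expert matmuls inline with torch/torchax ops; the rewritten method delegates both to fused_moe_apply and the backend chosen by select_moe_backend. As a reference for what the routing step computes, here is a minimal JAX sketch with hypothetical names (illustrative only, not the fused_moe_apply implementation):

import jax
import jax.numpy as jnp

def route_tokens(router_logits, top_k, renormalize=True):
    # Softmax over experts, then pick the top_k experts per token,
    # mirroring the removed inline routing in apply().
    probs = jax.nn.softmax(router_logits, axis=-1)   # (tokens, num_experts)
    weights, indices = jax.lax.top_k(probs, top_k)   # both (tokens, top_k)
    if renormalize:
        # Renormalize so the selected expert weights sum to 1 per token.
        weights = weights / jnp.sum(weights, axis=-1, keepdims=True)
    return weights, indices

# Example: 4 tokens routed over 8 experts, top-2.
logits = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
routing_weights, expert_indices = route_tokens(logits, top_k=2)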
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  from typing import Optional

  import jax
@@ -6,49 +20,27 @@ import torch
  from compressed_tensors.quantization import (QuantizationArgs,
                                               QuantizationStrategy)
  from jax.sharding import NamedSharding, PartitionSpec
+ from torch.nn.parameter import Parameter
  from torchax.interop import jax_view, torch_view
  from torchax.ops.mappings import t2j
  from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
      CompressedTensorsW8A8Fp8
- from vllm.model_executor.layers.quantization.utils.w8a8_utils import \
-     per_tensor_dequantize

- from tpu_inference.layers.vllm.linear_common import (
-     sharded_quantized_matmul, slice_sharded_tensor_for_concatenation,
-     torch_to_jax_param)
- from tpu_inference.layers.vllm.quantization.common import JaxCommonLinearConfig
+ from tpu_inference.layers.common.quantization import (dequantize_tensor,
+                                                        quantize_tensor)
+ from tpu_inference.layers.common.utils import \
+     slice_sharded_tensor_for_concatenation
+ from tpu_inference.layers.vllm.linear import sharded_quantized_matmul
+ from tpu_inference.layers.vllm.process_weights.linear_weights import (
+     LinearWeights, process_lienar_weights, shard_linear_weights,
+     to_parameter_list)
+ from tpu_inference.layers.vllm.quantization.configs import \
+     VllmQuantLinearConfig
+ from tpu_inference.logger import init_logger

  P = PartitionSpec

-
- def requantize_with_max_scale(
-         weight: torch.Tensor, weight_scale: torch.Tensor,
-         logical_widths: list[int]) -> tuple[torch.Tensor, torch.Tensor]:
-     dtype = weight.dtype
-     dtype_info = torch.finfo(dtype)
-     maxval = float(dtype_info.max)
-     minval = float(dtype_info.min)
-
-     max_w_scale = weight_scale.max()
-
-     unfused_module_in_checkpoint = (weight_scale[-1]
-                                     > torch.finfo(torch.float8_e4m3fn).min)
-
-     # If unfused checkpoint, need requanize with the single scale.
-     if unfused_module_in_checkpoint:
-         start = 0
-         for idx, logical_width in enumerate(logical_widths):
-             # Skip any component with zero width.
-             if logical_width == 0:
-                 continue
-             end = start + logical_width
-             weight_dq = per_tensor_dequantize(weight[start:end, :],
-                                               weight_scale[idx])
-             weight_q = weight_dq / max_w_scale
-             weight[start:end, :] = weight_q.clamp(minval, maxval).to(dtype)
-             start = end
-
-     return max_w_scale, weight
+ logger = init_logger(__name__)


  class VllmCompressedTensorsW8A8Fp8(CompressedTensorsW8A8Fp8):
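
The removed requantize_with_max_scale helper above is superseded by the shared dequantize_tensor/quantize_tensor utilities now imported from tpu_inference.layers.common.quantization. A minimal JAX sketch of the math it performed, assuming symmetric per-tensor fp8 scales (a reference for the reader, not the library code):

import jax.numpy as jnp

def requantize_with_max_scale_sketch(weight, weight_scale, logical_widths):
    # weight: (sum(logical_widths), in_features) stored as float8_e4m3fn
    # weight_scale: one float32 scale per fused sub-weight
    finfo = jnp.finfo(jnp.float8_e4m3fn)
    max_scale = weight_scale.max()
    slices, start = [], 0
    for idx, width in enumerate(logical_widths):
        # Dequantize this slice with its own scale, re-express it in the shared scale.
        w_dq = weight[start:start + width].astype(jnp.float32) * weight_scale[idx]
        slices.append(w_dq / max_scale)
        start += width
    requantized = jnp.clip(jnp.concatenate(slices, axis=0),
                           float(finfo.min), float(finfo.max))
    return max_scale, requantized.astype(jnp.float8_e4m3fn)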
@@ -57,15 +49,86 @@ class VllmCompressedTensorsW8A8Fp8(CompressedTensorsW8A8Fp8):
          self,
          weight_quant: QuantizationArgs,
          is_static_input_scheme: bool,
-         jax_config: JaxCommonLinearConfig,
+         linear_config: VllmQuantLinearConfig,
      ):
          super().__init__(weight_quant, is_static_input_scheme)

-         self.jax_config = jax_config
+         self.linear_config = linear_config

      def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-         weight = layer.weight
-         weight_scale = layer.weight_scale
+         weight = t2j(layer.weight, use_dlpack=False)
+         delattr(layer, "weight")
+         weight_scale = t2j(layer.weight_scale, use_dlpack=False)
+         delattr(layer, "weight_scale")
+
+         if layer.bias is not None and not layer.skip_bias_add:
+             if layer.return_bias:
+                 logger.warning_once("Bias might return incorrect value.")
+             bias = t2j(layer.bias, use_dlpack=False)
+             delattr(layer, "bias")
+         else:
+             bias = None
+
+         per_tensor = self.strategy == QuantizationStrategy.TENSOR
+
+         @jax.jit
+         def process_fp8_linear_weights(
+             weight: jax.Array,
+             weight_scale: jax.Array,
+             bias: jax.Array | None,
+         ) -> LinearWeights:
+             if per_tensor:
+                 weights = []
+                 start = 0
+                 # Multiple weights may have been concatenated. Loop through
+                 # each weight and perform dequantization.
+                 for i, output_size in enumerate(
+                         self.linear_config.output_sizes):
+                     end = start + output_size
+                     weights.append(
+                         dequantize_tensor(weight[start:end], weight_scale[i]))
+                     start = end
+                 weight = jnp.concat(weights, axis=0)
+                 # Requantize into per-tensor.
+                 weight, weight_scale = quantize_tensor(jnp.float8_e4m3fn,
+                                                        weight, None)
+             else:
+                 weight_scale = jnp.squeeze(weight_scale, -1)
+
+             return process_lienar_weights(
+                 LinearWeights(
+                     weight=weight,
+                     weight_scale=weight_scale,
+                     zero_point=None,
+                     bias=bias,
+                 ),
+                 fused=self.linear_config.fuse_matmuls,
+                 output_sizes=self.linear_config.output_sizes,
+                 reorder_size=self.linear_config.n_shards,
+                 per_tensor=per_tensor,
+             )
+
+         weights = process_fp8_linear_weights(weight, weight_scale, bias)
+         weights = torch_view(
+             shard_linear_weights(
+                 weights,
+                 mesh=self.linear_config.mesh,
+                 weight_p_spec=self.linear_config.weight_sharding,
+                 bias_p_spec=self.linear_config.bias_sharding,
+                 per_tensor=per_tensor,
+             ))
+
+         if self.linear_config.fuse_matmuls:
+             layer.weight = Parameter(weights.weight, requires_grad=False)
+             layer.weight_scale = Parameter(weights.weight_scale,
+                                            requires_grad=False)
+             if bias is not None:
+                 layer.bias = Parameter(weights.bias, requires_grad=False)
+         else:
+             layer.weight = to_parameter_list(weights.weight)
+             layer.weight_scale = to_parameter_list(weights.weight_scale)
+             if bias is not None:
+                 layer.bias = to_parameter_list(weights.bias)

          if self.is_static_input_scheme:
              # In static quant, all input_scales share the same value.
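
The per-tensor branch above dequantizes each fused sub-weight with its own scale, concatenates, and requantizes once against a single scale. The sketch below shows plausible semantics for the dequantize_tensor and quantize_tensor helpers as inferred from their call sites; the actual implementations in tpu_inference.layers.common.quantization may differ (for example in scale shapes or clipping behaviour):

import jax.numpy as jnp

def dequantize_tensor_sketch(q, scale):
    # Quantized values times a broadcastable scale -> float32 values.
    return q.astype(jnp.float32) * scale

def quantize_tensor_sketch(dtype, x, scale=None):
    # Symmetric quantization; derive a per-tensor scale when none is given.
    finfo = jnp.finfo(dtype)
    if scale is None:
        scale = jnp.max(jnp.abs(x)) / float(finfo.max)
    q = jnp.clip(x / scale, float(finfo.min), float(finfo.max)).astype(dtype)
    return q, scale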
@@ -74,59 +137,16 @@ class VllmCompressedTensorsW8A8Fp8(CompressedTensorsW8A8Fp8):

              input_scale = jax.device_put(
                  t2j(input_scale_first, use_dlpack=False),
-                 NamedSharding(self.jax_config.mesh, P()))
+                 NamedSharding(self.linear_config.mesh, P()))
              input_scale = torch.nn.Parameter(torch_view(input_scale),
                                               requires_grad=False)
              delattr(layer, "input_scale")
              layer.input_scale = input_scale

-         # TODO(kyuyeunk): Investigate performance gain from merging scales.
-         # By merging input and weight scales, we reduce the number of muls
-         # required for dequantization from 2 (for each scales) to 1.
-         # weight_scale *= input_scale_first
-
-         if self.strategy == QuantizationStrategy.TENSOR:
-             weight_scale, weight = requantize_with_max_scale(
-                 weight, weight_scale, self.jax_config.output_sizes)
-             weight_scale = jax.device_put(
-                 t2j(weight_scale, use_dlpack=False),
-                 NamedSharding(self.jax_config.mesh, P()))
-             weight_scale = torch.nn.Parameter(torch_view(weight_scale),
-                                               requires_grad=False)
-         else:
-             weight_scale = weight_scale.squeeze(-1)
-             weight_scale = torch_to_jax_param(
-                 weight_scale,
-                 NamedSharding(self.jax_config.mesh,
-                               self.jax_config.bias_sharding),
-                 self.jax_config.output_sizes, self.jax_config.n_shards,
-                 self.jax_config.fuse_matmuls)
-         delattr(layer, "weight_scale")
-         layer.weight_scale = weight_scale
-
-         weight = torch_to_jax_param(
-             layer.weight,
-             NamedSharding(self.jax_config.mesh,
-                           self.jax_config.weight_sharding),
-             self.jax_config.output_sizes, self.jax_config.n_shards,
-             self.jax_config.fuse_matmuls)
-         delattr(layer, "weight")
-         layer.weight = weight
-
-         if layer.bias is not None:
-             bias = torch_to_jax_param(
-                 layer.bias,
-                 NamedSharding(self.jax_config.mesh,
-                               self.jax_config.bias_sharding),
-                 self.jax_config.output_sizes, self.jax_config.n_shards,
-                 self.jax_config.fuse_matmuls)
-             delattr(layer, "bias")
-             layer.bias = bias
-
      def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                        bias: Optional[torch.Tensor]) -> torch.Tensor:
          with jax.named_scope(layer._get_name()):
-             if self.jax_config.fuse_matmuls:
+             if self.linear_config.fuse_matmuls:
                  return self._apply_fused(layer, x, bias)
              else:
                  return self._apply_split(layer, x, bias)
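
Placement of the processed weights is now delegated to shard_linear_weights, but the mechanism is the same jax.device_put with a NamedSharding that the surviving input_scale code still uses: PartitionSpec() replicates an array on every device, while a spec that names a mesh axis splits the corresponding dimension. A minimal sketch, assuming a 1-D "model" mesh axis and illustrative shapes (not the shard_linear_weights implementation):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

out_features = 1024 * len(jax.devices())
weight = jnp.zeros((out_features, 512), dtype=jnp.float8_e4m3fn)
scale = jnp.ones((), dtype=jnp.float32)

# Column-parallel weight: split the output dimension across the "model" axis.
weight = jax.device_put(weight, NamedSharding(mesh, P("model", None)))
# Per-tensor scale: replicated on every device.
scale = jax.device_put(scale, NamedSharding(mesh, P()))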
@@ -157,13 +177,13 @@ class VllmCompressedTensorsW8A8Fp8(CompressedTensorsW8A8Fp8):
          else:
              outs = sharded_quantized_matmul(x_jax, weight_jax,
                                              weight_scale_jax,
-                                             self.jax_config.mesh,
-                                             self.jax_config.weight_sharding)
+                                             self.linear_config.mesh,
+                                             self.linear_config.weight_sharding)

          if bias is not None and not layer.skip_bias_add:
              outs += jax_view(bias)
          outs = slice_sharded_tensor_for_concatenation(
-             outs, self.jax_config.output_sizes, self.jax_config.n_shards)
+             outs, self.linear_config.output_sizes, self.linear_config.n_shards)
          return torch_view(jnp.concatenate(outs, axis=-1))

      def _apply_split(self, layer: torch.nn.Module, x: torch.Tensor,
@@ -197,10 +217,10 @@ class VllmCompressedTensorsW8A8Fp8(CompressedTensorsW8A8Fp8):
                  out *= weight_scale_jax * input_scale
                  out = out.astype(x_jax.dtype)
              else:
-                 out = sharded_quantized_matmul(x_jax, weight_jax,
-                                                weight_scale_jax,
-                                                self.jax_config.mesh,
-                                                self.jax_config.weight_sharding)
+                 out = sharded_quantized_matmul(
+                     x_jax, weight_jax, weight_scale_jax,
+                     self.linear_config.mesh,
+                     self.linear_config.weight_sharding)

              if bias is not None and not layer.skip_bias_add:
                  out += jax_view(bias[i])
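
For reference, the sharded_quantized_matmul calls in both branches logically compute a matmul against a quantized weight followed by a per-output-channel rescale. A dequantize-then-matmul sketch of that semantics (the real kernel keeps the weight quantized and applies mesh sharding constraints, so this is only a reference, not the tpu_inference implementation):

import jax.numpy as jnp

def quantized_matmul_reference(x, w_q, w_scale):
    # x: (tokens, in_features) activations
    # w_q: (out_features, in_features) quantized weight, e.g. float8_e4m3fn
    # w_scale: (out_features,) per-output-channel scale
    w = w_q.astype(jnp.float32) * w_scale[:, None]            # dequantize
    out = jnp.einsum("ti,oi->to", x.astype(jnp.float32), w)   # x @ w.T
    return out.astype(x.dtype)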