tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.
Files changed (250)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +78 -1
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +1 -43
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +14 -9
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +38 -7
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +17 -0
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +28 -5
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +74 -35
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +88 -25
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -64
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +72 -37
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +45 -15
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +14 -0
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +41 -16
  232. tpu_inference/spec_decode/__init__.py +13 -0
  233. tpu_inference/spec_decode/jax/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  235. tpu_inference/tpu_info.py +14 -0
  236. tpu_inference/utils.py +42 -36
  237. tpu_inference/worker/__init__.py +13 -0
  238. tpu_inference/worker/tpu_worker.py +63 -50
  239. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
  240. tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
  241. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  242. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  245. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  246. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  247. tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
  248. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
  249. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
  250. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -1,203 +1,266 @@
- from typing import Callable, Optional, Union
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Union
 
  import jax
  import jax.numpy as jnp
  import torch
- import torch.nn.functional as F
+ from compressed_tensors.quantization import QuantizationArgs
  from jax.experimental.layout import Format, Layout
  from jax.sharding import Mesh, NamedSharding
  from jax.sharding import PartitionSpec as P
  from torch.nn.parameter import Parameter
- from torchax.interop import call_jax, torch_view
+ from torchax.interop import jax_view, torch_view
  from torchax.ops.mappings import t2j
  from vllm.logger import init_logger
  from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEConfig
- from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import \
-     CompressedTensorsConfig
- from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import \
-     CompressedTensorsW8A8Fp8MoEMethod
- from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (  # noqa
-     WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
+ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (  # noqa: E501
+     CompressedTensorsMoEMethod, CompressedTensorsW8A8Fp8MoEMethod)
 
+ from tpu_inference.layers.vllm.fused_moe import fused_moe_func
+ from tpu_inference.layers.vllm.linear_common import \
+     reorder_concatenated_tensor_for_sharding
  from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
+ from tpu_inference.layers.vllm.quantization.unquantized import \
+     VllmUnquantizedFusedMoEMethod
 
  logger = init_logger(__name__)
 
 
+ class VllmCompressedTensorsMoEMethod(CompressedTensorsMoEMethod):
+
+     @staticmethod
+     def get_moe_method(
+         quant_config: "VllmCompressedTensorsConfig",  # type: ignore # noqa E501
+         layer: torch.nn.Module,
+         layer_name: str,
+     ) -> CompressedTensorsMoEMethod:
+         assert isinstance(layer, FusedMoE)
+
+         # FusedMoE was made by combining multiple Linears, so we need to
+         # make sure the quantization config for Linear can target it.
+         quant_config._add_fused_moe_to_target_scheme_map()
+         unfused_names = [
+             layer_name + proj_name
+             for proj_name in [".0.gate_proj", ".0.up_proj", ".0.down_proj"]
+         ]
+         # TODO: refactor this to use expert_mapping and check all layer numbers
+         all_scheme_dicts = [
+             quant_config.get_scheme_dict(layer, name) for name in unfused_names
+         ]
+         scheme_dict = all_scheme_dicts.pop()
+
+         # multiple schemes found
+         if not all([cur_dict == scheme_dict for cur_dict in all_scheme_dicts]):
+             raise ValueError("All MoE projections need to have same "
+                              "quantization scheme but found multiple")
+
+         if scheme_dict is None:
+             return VllmUnquantizedFusedMoEMethod(layer.moe_config,
+                                                  quant_config.mesh)
+
+         weight_quant = scheme_dict.get("weights")
+         input_quant = scheme_dict.get("input_activations")
+
+         if quant_config._is_fp8_w8a8(weight_quant, input_quant):
+             return VllmCompressedTensorsW8A8Fp8MoEMethod(
+                 weight_quant, input_quant, layer.moe_config, quant_config.mesh)
+         else:
+             raise RuntimeError(
+                 f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
+
+
  class VllmCompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsW8A8Fp8MoEMethod,
                                              JaxCommonConfig):
 
-     def __init__(self, quant_config: "CompressedTensorsConfig",
-                  moe: FusedMoEConfig, mesh: Mesh):
-         super().__init__(quant_config, moe)
+     def __init__(
+         self,
+         weight_quant: QuantizationArgs,
+         input_quant: QuantizationArgs,
+         moe: FusedMoEConfig,
+         mesh: Mesh,
+     ):
+         super().__init__(weight_quant, input_quant, moe)
          self.mesh = mesh
-         self.quant_config = quant_config
-
-         # disable GPU paths
-         self.use_marlin = False
-         self.rocm_aiter_moe_enabled = False  # is_rocm_aiter_moe_enabled()
-         self.is_fp8_w8a8_sm100 = False
-         self.use_cutlass = False
-         self.disable_expert_map = False
 
      def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+         """Convert the loaded fp8 weights into sharded JAX arrays.
+
+         Steps:
+         1. Read weights from the layer object and convert them to JAX arrays.
+         2. Interleave the concatenated w13 weights.
+         3. Shard weights for TP (rowwise w13, colwise w2).
+         4. Re-initialize the params as torch.nn.Parameter:
+            a. w13_weight - float8_e4m3fn, shape (num_experts, 2 x intermediate_size, input_size)
+            b. w2_weight - float8_e4m3fn, shape (num_experts, output_size, intermediate_size)
+            c. w13_weight_scale - FP32, shape (num_experts, 2 x intermediate_size, 1)
+            d. w2_weight_scale - FP32, shape (num_experts, output_size, 1)
+         """
          assert isinstance(layer, FusedMoE)
 
+         # Read weights from layer object
+         w13_weight = t2j(
+             layer.w13_weight, use_dlpack=False
+         )  # float8_e4m3fn shape: (num_experts, 2 x intermediate_size, input_size)
+         w13_weight_scale = t2j(
+             layer.w13_weight_scale, use_dlpack=False
+         )  # FP32 shape: (num_experts, 2 x intermediate_size, 1)
+         w2_weight = t2j(
+             layer.w2_weight, use_dlpack=False
+         )  # float8_e4m3fn shape: (num_experts, output_size, intermediate_size)
+         w2_weight_scale = t2j(layer.w2_weight_scale, use_dlpack=False)
+         w13_weight_scale = w13_weight_scale.astype(jnp.bfloat16)
+         w2_weight_scale = w2_weight_scale.astype(jnp.bfloat16)
          intermediate_size = layer.w13_weight.shape[1] // 2
-         w1_weight = layer.w13_weight[:, :intermediate_size]
-         w3_weight = layer.w13_weight[:, intermediate_size:]
-         w1_weight_scale = layer.w13_weight_scale[:, :intermediate_size]
-         w3_weight_scale = layer.w13_weight_scale[:, intermediate_size:]
-
-         w2_weight = t2j(layer.w2_weight, use_dlpack=False)
-         w2_weight_scale = t2j(layer.w2_weight_scale.to(torch.bfloat16),
-                               use_dlpack=False)
-         w1_weight = t2j(w1_weight, use_dlpack=False)
-         w1_weight_scale = t2j(w1_weight_scale.to(torch.bfloat16),
-                               use_dlpack=False)
-         w3_weight = t2j(w3_weight, use_dlpack=False)
-         w3_weight_scale = t2j(w3_weight_scale.to(torch.bfloat16),
-                               use_dlpack=False)
+         assert intermediate_size == w2_weight.shape[-1]
+         n_shards = self.mesh.shape["model"]
+         assert intermediate_size % n_shards == 0
+         num_experts, hidden_size, intermediate_size = w2_weight.shape
+         assert w2_weight_scale.shape == (num_experts, hidden_size, 1)
+         assert w13_weight.shape == (num_experts, 2 * intermediate_size,
+                                     hidden_size)
+         assert w13_weight_scale.shape == (num_experts, 2 * intermediate_size,
+                                           1)
+
+         if not layer.use_ep:
+             # Interleave concat w13 weights
+             w13_weight = reorder_concatenated_tensor_for_sharding(
+                 w13_weight,
+                 split_sizes=(intermediate_size, intermediate_size),
+                 dim=1,
+                 n_shards=n_shards,
+             )
+             # Interleave concat w13 weight scales
+             w13_weight_scale = reorder_concatenated_tensor_for_sharding(
+                 w13_weight_scale,
+                 split_sizes=(intermediate_size, intermediate_size),
+                 dim=1,
+                 n_shards=n_shards,
+             )
+
+         # 160,5120,1 -> 160,1,5120
+         w13_weight_scale = jnp.swapaxes(w13_weight_scale, 1, 2)
+         # 160,1,5120 -> 160,1,1,5120 (num_experts, num_blocks, 1, outer_dim)
+         w13_weight_scale = jnp.expand_dims(w13_weight_scale, 2)
+         w2_weight_scale = jnp.swapaxes(w2_weight_scale, 1, 2)
+         w2_weight_scale = jnp.expand_dims(w2_weight_scale, 2)
 
          if layer.use_ep:
-             format = Format(Layout((0, 1, 2)),
-                             NamedSharding(self.mesh, P("model", None, None)))
-             w1_weight = jax.device_put(w1_weight, format)
-             w1_weight_scale = jax.device_put(w1_weight_scale, format)
-             w3_weight = jax.device_put(w3_weight, format)
-             w3_weight_scale = jax.device_put(w3_weight_scale, format)
-             w2_weight = jax.device_put(w2_weight, format)
-             w2_weight_scale = jax.device_put(w2_weight_scale, format)
-         else:
-             assert intermediate_size == w2_weight.shape[-1]
-             n_shards = self.mesh.shape["model"]
-             assert intermediate_size % n_shards == 0
+             # Apply EP sharding
+             ep_sharding = NamedSharding(self.mesh, P("model"))
+
+             w13_weight = jax.lax.with_sharding_constraint(
+                 w13_weight, ep_sharding)
+             w2_weight = jax.lax.with_sharding_constraint(
+                 w2_weight, ep_sharding)
 
-             # TODO: enable this if using fused weights
-             # output_sizes = [intermediate_size, intermediate_size]
-             # w13_weight = reorder_concatenated_tensor_for_sharding(
-             #     w13_weight, output_sizes, n_shards, dim=1
-             # )
+             w13_weight_scale = jax.lax.with_sharding_constraint(
+                 w13_weight_scale, ep_sharding)
+             w2_weight_scale = jax.lax.with_sharding_constraint(
+                 w2_weight_scale, ep_sharding)
 
+         else:
+             # Shard weights for tp (rowwise w13, colwise w2)
              w13_format = Format(
-                 Layout((0, 1, 2)),
-                 NamedSharding(self.mesh, P(None, "model", None)))
-             w1_weight = jax.device_put(w1_weight, w13_format)
-             w1_weight_scale = jax.device_put(w1_weight_scale, w13_format)
-             w3_weight = jax.device_put(w3_weight, w13_format)
-             w3_weight_scale = jax.device_put(w3_weight_scale, w13_format)
-             w2_weight = jax.device_put(
-                 w2_weight,
-                 Format(Layout((0, 1, 2)),
-                        NamedSharding(self.mesh, P(None, None, "model"))),
+                 Layout((0, 1, 2)),  # expert, 2 x intermed, input
+                 NamedSharding(self.mesh, P(None, "model", None)),
+             )  # rowwise sharding on intermed dim
+
+             w13_scale_format = Format(
+                 Layout((0, 1, 2, 3)),  # (num_experts, num_blocks, 1, outer_dim)
+                 NamedSharding(self.mesh, P(None, None, None, "model")),
+             )  # colwise GMM sharding on intermed dim
+
+             # Local shard shape: (num_experts, 2 x (intermediate_size // n_shards), input_size)
+             w13_weight = jax.lax.with_sharding_constraint(
+                 w13_weight, w13_format)
+             # Local shard shape: (num_experts, (intermediate_size // n_shards), 1)
+             w13_weight_scale = jax.lax.with_sharding_constraint(
+                 w13_weight_scale, w13_scale_format)
+
+             # Shard weights for tp (colwise w2)
+             w2_format = Format(
+                 Layout((0, 1, 2)),  # expert, intermed, hidden
+                 NamedSharding(self.mesh, P(None, None, "model")),
              )
-             w2_weight_scale = jax.device_put(
-                 w2_weight_scale,
-                 Format(Layout((0, 1, 2)), NamedSharding(self.mesh, P())),
-             )  # replicate
+             # Local shard shape: (num_experts, hidden, (intermediate_size // n_shards))
+             # (num_experts, num_blocks, 1, outer_dim)
+             w2_weight = jax.lax.with_sharding_constraint(w2_weight, w2_format)
 
-         w1_weight = Parameter(torch_view(w1_weight), requires_grad=False)
-         w1_weight_scale = Parameter(torch_view(w1_weight_scale),
-                                     requires_grad=False)
+             w2_scale_format = Format(
+                 Layout((0, 1, 2, 3)),  # expert, intermed, 1
+                 NamedSharding(self.mesh, P(None, None, None, None)),
+             )
+             # Local shard shape: (num_experts, intermediate_size // n_shards, 1)
+             w2_weight_scale = jax.lax.with_sharding_constraint(
+                 w2_weight_scale, w2_scale_format)
+
+         w13_weight = Parameter(torch_view(w13_weight), requires_grad=False)
+         w13_weight_scale = Parameter(torch_view(w13_weight_scale),
+                                      requires_grad=False)
          w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
          w2_weight_scale = Parameter(torch_view(w2_weight_scale),
                                      requires_grad=False)
-         w3_weight = Parameter(torch_view(w3_weight), requires_grad=False)
-         w3_weight_scale = Parameter(torch_view(w3_weight_scale),
-                                     requires_grad=False)
 
-         # TODO dont reuse variable
-         layer.w13_weight = w1_weight
-         layer.w13_weight_scale = w1_weight_scale
+         layer.w13_weight = w13_weight
+         layer.w13_weight_scale = w13_weight_scale
          layer.w2_weight = w2_weight
          layer.w2_weight_scale = w2_weight_scale
-         layer.w3_weight = w3_weight
-         layer.w3_weight_scale = w3_weight_scale
 
      def apply(
          self,
          layer: torch.nn.Module,
          x: torch.Tensor,
          router_logits: torch.Tensor,
-         top_k: int,
-         renormalize: bool,
-         use_grouped_topk: bool = False,
-         topk_group: Optional[int] = None,
-         num_expert_group: Optional[int] = None,
-         global_num_experts: int = -1,
-         expert_map: Optional[torch.Tensor] = None,
-         custom_routing_function: Optional[Callable] = None,
-         scoring_func: str = "softmax",
-         routed_scaling_factor: float = 1.0,
-         e_score_correction_bias: Optional[torch.Tensor] = None,
-         apply_router_weight_on_input: bool = False,
-         activation: str = "silu",
-         enable_eplb: bool = False,
-         expert_load_view: Optional[torch.Tensor] = None,
-         logical_to_physical_map: Optional[torch.Tensor] = None,
-         logical_replica_count: Optional[torch.Tensor] = None,
      ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
          assert isinstance(layer, FusedMoE)
-         if activation != "silu":
+         if layer.activation != "silu":
              raise NotImplementedError(
                  "Only silu is supported for activation function.")
-         if scoring_func != "softmax":
+         if layer.scoring_func != "softmax":
              raise NotImplementedError(
                  "Only softmax is supported for scoring_func")
 
-         # import sys
-         # sys.stdin = open(0)
-         # breakpoint()
-
          # TODO: Use MoE kernel when it supports fp8
-
-         seqlen = x.shape[0]
-
-         expert_weights = F.softmax(router_logits, dim=-1)
-         expert_weights, expert_indices = torch.topk(expert_weights,
-                                                     top_k,
-                                                     dim=-1)
-         if renormalize:
-             expert_weights /= expert_weights.sum(dim=-1, keepdim=True)
-
-         # cond ffn
-         # e = total num of exp = 160
-         # t = seqlen
-         # o = config.imtermediate size
-         # i = config.dim
-         # torch.einsum("ti, eoi -> teo", x, layer.w13_weight) * self.w13_weight_scale)
-         ux1 = call_jax(jax.lax.dot,
-                        x,
-                        layer.w13_weight,
-                        dimension_numbers=(((1, ), (2, )), ((), ())),
-                        preferred_element_type=jnp.bfloat16.dtype)
-         x1 = F.silu(ux1 * layer.w13_weight_scale.squeeze(2))
-
-         # x3 = torch.einsum("ti, eoi -> teo", x, layer.w3_weight) * self.w3_weight_scale
-         x3 = call_jax(jax.lax.dot,
-                       x,
-                       layer.w3_weight,
-                       dimension_numbers=(((1, ), (2, )), ((), ())),
-                       preferred_element_type=jnp.bfloat16.dtype
-                       ) * layer.w3_weight_scale.squeeze(2)
-
-         # expert_outs = torch.einsum("teo, eio -> tei", (x1 * x3), self.w2_weight) * self.w2_weight_scale
-         expert_outs = call_jax(
-             jax.lax.dot,
-             x1 * x3,
-             layer.w2_weight,
-             dimension_numbers=(((2, ), (2, )), ((1, ), (0, ))),
-             preferred_element_type=jnp.bfloat16.dtype).transpose(
-                 0, 1) * layer.w2_weight_scale.squeeze(2)
-
-         seq_indexes = torch.arange(seqlen, device='jax').unsqueeze(1)
-         expert_outs = expert_outs[seq_indexes, expert_indices]
-
-         # out = torch.einsum("tai,ta -> ti", expert_outs, expert_weights)
-         out = call_jax(jax.lax.dot,
-                        expert_outs,
-                        expert_weights,
-                        dimension_numbers=(((1, ), (1, )), ((0, ), (0, ))),
-                        preferred_element_type=jnp.bfloat16.dtype)
+         x = jax_view(x)
+         w13_weight = jax_view(layer.w13_weight)
+         w2_weight = jax_view(layer.w2_weight)
+         w13_weight_scale = jax_view(layer.w13_weight_scale)
+         w2_weight_scale = jax_view(layer.w2_weight_scale)
+         gating_output = jax_view(router_logits)
+         out = torch_view(
+             fused_moe_func(
+                 hidden_states=x,
+                 w1=w13_weight,
+                 w2=w2_weight,
+                 w1_scale=w13_weight_scale,
+                 w2_scale=w2_weight_scale,
+                 w1_bias=None,
+                 w2_bias=None,
+                 gating_output=gating_output,
+                 topk=layer.top_k,
+                 renormalize=layer.renormalize,
+                 mesh=self.mesh,
+                 use_ep=layer.use_ep,
+                 activation=layer.activation,
+             ))
 
          return out
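
In the non-EP branch above, reorder_concatenated_tensor_for_sharding interleaves the gate and up halves of w13 (and its scales) so that an even split along the "model" mesh axis leaves each shard with matching gate and up rows. A minimal NumPy sketch of that interleaving for equal split_sizes (a hypothetical stand-in for illustration only; the actual helper lives in tpu_inference.layers.vllm.linear_common and may differ in details):

    import numpy as np

    def reorder_for_sharding(x: np.ndarray, block: int, n_shards: int) -> np.ndarray:
        # x concatenates two equal blocks of `block` rows along axis 1,
        # matching the w13 layout (num_experts, 2 * intermediate_size, hidden).
        per_shard = block // n_shards
        first, second = x[:, :block], x[:, block:]
        pieces = []
        for s in range(n_shards):
            pieces.append(first[:, s * per_shard:(s + 1) * per_shard])
            pieces.append(second[:, s * per_shard:(s + 1) * per_shard])
        return np.concatenate(pieces, axis=1)

    # 1 expert, intermediate_size=4, 2 shards: gate rows g0..g3 and up rows
    # u0..u3 become [g0 g1 u0 u1 | g2 g3 u2 u3], so an even split over the
    # "model" axis gives every shard its matching gate and up rows.
    x = np.arange(8).reshape(1, 8, 1)
    print(reorder_for_sharding(x, block=4, n_shards=2)[0, :, 0])
    # [0 1 4 5 2 3 6 7]

Without this reordering, a plain even split of the concatenated tensor would put all gate rows on the first shards and all up rows on the last ones, breaking the per-shard gate/up pairing that the fused kernel expects.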
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  from typing import Optional
 
  import jax
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  from typing import Optional
 
  import jax
tpu_inference/layers/vllm/quantization/fp8.py
@@ -0,0 +1,118 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional, Union
+
+ import jax
+ import torch
+ from jax.sharding import PartitionSpec
+ from vllm.logger import init_logger
+ from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+ from vllm.model_executor.layers.quantization import \
+     register_quantization_config
+ from vllm.model_executor.layers.quantization.base_config import \
+     QuantizeMethodBase
+ from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
+                                                          Fp8LinearMethod)
+ from vllm.model_executor.layers.quantization.utils.quant_utils import \
+     is_layer_skipped
+
+ from tpu_inference.layers.common.quant_methods import FP8, get_tpu_quant_method
+ from tpu_inference.layers.vllm.quantization.common import (
+     JaxCommonConfig, JaxCommonLinearConfig)
+ from tpu_inference.layers.vllm.quantization.unquantized import \
+     VllmUnquantizedLinearMethod
+
+ P = PartitionSpec
+ logger = init_logger(__name__)
+
+
+ @register_quantization_config(get_tpu_quant_method(FP8))
+ class VllmFp8Config(Fp8Config, JaxCommonConfig):
+
+     @classmethod
+     def get_name(cls):
+         return FP8
+
+     def get_supported_act_dtypes(self) -> list[torch.dtype]:
+         return [torch.bfloat16]
+
+     def get_quant_method(
+         self, layer: torch.nn.Module, prefix: str
+     ) -> Optional[Union["LinearMethodBase", "QuantizeMethodBase"]]:
+         if isinstance(layer, LinearBase):
+             linear_config = self.get_linear_config(layer)
+             if is_layer_skipped(prefix, self.ignored_layers):
+                 return VllmUnquantizedLinearMethod(linear_config)
+             return VllmFp8LinearMethod(self, linear_config)
+         elif isinstance(layer, FusedMoE):
+             raise NotImplementedError(
+                 "FP8 FusedMoE is currently not supported in torchax-jax")
+         return None
+
+
+ class VllmFp8LinearMethod(Fp8LinearMethod):
+
+     def __init__(self, quant_config: VllmFp8Config,
+                  jax_config: JaxCommonLinearConfig):
+         super().__init__(quant_config)
+         self.jax_config = jax_config
+         self._configure_sharding()
+
+     def _configure_sharding(self) -> None:
+
+         raise NotImplementedError(
+             "Configure PartitionSpec for weight_sharding and scale_sharding "
+             "based on layer type (RowParallel/ColumnParallel)")
+
+     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+
+         raise NotImplementedError(
+             "Convert layer.weight, layer.weight_scale, and optionally "
+             "layer.input_scale and layer.bias from torch tensors to JAX arrays "
+             "using torch_to_jax_param() with appropriate sharding")
+
+     def apply(self,
+               layer: torch.nn.Module,
+               x: torch.Tensor,
+               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+         with jax.named_scope(layer._get_name()):
+             if self.jax_config.fuse_matmuls:
+                 out = self._apply_fused(layer, x, bias)
+             else:
+                 out = self._apply_split(layer, x, bias)
+
+         return out
+
+     def _apply_fused(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+         raise NotImplementedError(
+             "Implement single matmul for fused outputs: "
+             "quantize input to fp8, perform fp8 matmul with weight and scales, "
+             "dequantize output, and add bias if present")
+
+     def _apply_split(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+         raise NotImplementedError(
+             "Implement separate matmuls per output partition: "
+             "split weight/scale by output_sizes, perform fp8 matmul for each, "
+             "concatenate results, and add bias if present")