tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their public registries.

Potentially problematic release.


This version of tpu-inference might be problematic.

Files changed (257)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +317 -34
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +406 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +320 -0
  64. tests/layers/vllm/test_unquantized.py +662 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +26 -6
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +25 -4
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +25 -12
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  154. tpu_inference/layers/common/quant_methods.py +15 -0
  155. tpu_inference/layers/common/quantization.py +282 -0
  156. tpu_inference/layers/common/sharding.py +32 -9
  157. tpu_inference/layers/common/utils.py +94 -0
  158. tpu_inference/layers/jax/__init__.py +13 -0
  159. tpu_inference/layers/jax/attention/__init__.py +13 -0
  160. tpu_inference/layers/jax/attention/attention.py +19 -6
  161. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  162. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  163. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  164. tpu_inference/layers/jax/base.py +14 -0
  165. tpu_inference/layers/jax/constants.py +13 -0
  166. tpu_inference/layers/jax/layers.py +14 -0
  167. tpu_inference/layers/jax/misc.py +14 -0
  168. tpu_inference/layers/jax/moe/__init__.py +13 -0
  169. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  170. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  171. tpu_inference/layers/jax/moe/moe.py +43 -3
  172. tpu_inference/layers/jax/pp_utils.py +53 -0
  173. tpu_inference/layers/jax/rope.py +14 -0
  174. tpu_inference/layers/jax/rope_interface.py +14 -0
  175. tpu_inference/layers/jax/sample/__init__.py +13 -0
  176. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  177. tpu_inference/layers/jax/sample/sampling.py +15 -1
  178. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  179. tpu_inference/layers/jax/transformer_block.py +14 -0
  180. tpu_inference/layers/vllm/__init__.py +13 -0
  181. tpu_inference/layers/vllm/attention.py +4 -4
  182. tpu_inference/layers/vllm/fused_moe.py +101 -494
  183. tpu_inference/layers/vllm/linear.py +64 -0
  184. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  185. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  186. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  187. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  188. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  189. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  191. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
  192. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
  193. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  194. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  195. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  196. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
  197. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  198. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
  199. tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
  200. tpu_inference/lora/__init__.py +13 -0
  201. tpu_inference/lora/torch_lora_ops.py +8 -13
  202. tpu_inference/models/__init__.py +13 -0
  203. tpu_inference/models/common/__init__.py +13 -0
  204. tpu_inference/models/common/model_loader.py +112 -35
  205. tpu_inference/models/jax/__init__.py +13 -0
  206. tpu_inference/models/jax/deepseek_v3.py +267 -157
  207. tpu_inference/models/jax/gpt_oss.py +26 -10
  208. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  209. tpu_inference/models/jax/llama3.py +99 -36
  210. tpu_inference/models/jax/llama4.py +14 -0
  211. tpu_inference/models/jax/llama_eagle3.py +18 -5
  212. tpu_inference/models/jax/llama_guard_4.py +15 -1
  213. tpu_inference/models/jax/qwen2.py +17 -2
  214. tpu_inference/models/jax/qwen2_5_vl.py +179 -51
  215. tpu_inference/models/jax/qwen3.py +17 -2
  216. tpu_inference/models/jax/utils/__init__.py +13 -0
  217. tpu_inference/models/jax/utils/file_utils.py +14 -0
  218. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  219. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  220. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
  221. tpu_inference/models/jax/utils/weight_utils.py +234 -155
  222. tpu_inference/models/vllm/__init__.py +13 -0
  223. tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
  224. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  225. tpu_inference/platforms/__init__.py +14 -0
  226. tpu_inference/platforms/tpu_platform.py +51 -72
  227. tpu_inference/runner/__init__.py +13 -0
  228. tpu_inference/runner/compilation_manager.py +180 -80
  229. tpu_inference/runner/kv_cache.py +54 -20
  230. tpu_inference/runner/kv_cache_manager.py +55 -33
  231. tpu_inference/runner/lora_utils.py +16 -1
  232. tpu_inference/runner/multimodal_manager.py +16 -2
  233. tpu_inference/runner/persistent_batch_manager.py +54 -2
  234. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  235. tpu_inference/runner/structured_decoding_manager.py +16 -3
  236. tpu_inference/runner/tpu_runner.py +124 -61
  237. tpu_inference/runner/utils.py +2 -2
  238. tpu_inference/spec_decode/__init__.py +13 -0
  239. tpu_inference/spec_decode/jax/__init__.py +13 -0
  240. tpu_inference/spec_decode/jax/eagle3.py +84 -22
  241. tpu_inference/tpu_info.py +14 -0
  242. tpu_inference/utils.py +72 -44
  243. tpu_inference/worker/__init__.py +13 -0
  244. tpu_inference/worker/tpu_worker.py +66 -52
  245. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
  246. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  247. tpu_inference/layers/vllm/linear_common.py +0 -186
  248. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  249. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  250. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  251. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  252. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  253. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  254. tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
  255. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  256. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  257. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -1,203 +1,199 @@
-from typing import Callable, Optional, Union
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 import jax
-import jax.numpy as jnp
 import torch
-import torch.nn.functional as F
-from jax.experimental.layout import Format, Layout
-from jax.sharding import Mesh, NamedSharding
-from jax.sharding import PartitionSpec as P
+from compressed_tensors.quantization import QuantizationArgs
+from jax.sharding import Mesh
 from torch.nn.parameter import Parameter
-from torchax.interop import call_jax, torch_view
+from torchax.interop import torch_view
 from torchax.ops.mappings import t2j
-from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEConfig
-from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import \
-    CompressedTensorsConfig
-from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import \
-    CompressedTensorsW8A8Fp8MoEMethod
-from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (  # noqa
-    WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
-
-from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (  # noqa: E501
+    CompressedTensorsMoEMethod, CompressedTensorsW8A8Fp8MoEMethod)
+
+from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.vllm.fused_moe import (FusedMoEBackend,
+                                                 fused_moe_apply,
+                                                 select_moe_backend)
+from tpu_inference.layers.vllm.process_weights.fused_moe_weights import (
+    FusedMoEWeights, process_moe_weights, shard_moe_weights)
+from tpu_inference.layers.vllm.quantization.configs import VllmQuantConfig
+from tpu_inference.layers.vllm.quantization.unquantized import \
+    VllmUnquantizedFusedMoEMethod
+from tpu_inference.logger import init_logger
+from tpu_inference.utils import get_mesh_shape_product
 
 logger = init_logger(__name__)
 
 
+class VllmCompressedTensorsMoEMethod(CompressedTensorsMoEMethod):
+
+    @staticmethod
+    def get_moe_method(
+        quant_config: "VllmCompressedTensorsConfig",  # type: ignore # noqa E501
+        layer: torch.nn.Module,
+        layer_name: str,
+    ) -> CompressedTensorsMoEMethod:
+        assert isinstance(layer, FusedMoE)
+
+        # FusedMoE was made by combining multiple Linears so need to
+        # make sure quantization config for Linear can target it
+        quant_config._add_fused_moe_to_target_scheme_map()
+        unfused_names = [
+            layer_name + proj_name
+            for proj_name in [".0.gate_proj", ".0.up_proj", ".0.down_proj"]
+        ]
+        # TODO: refactor this to use expert_mapping and check all layer numbers
+        all_scheme_dicts = [
+            quant_config.get_scheme_dict(layer, name) for name in unfused_names
+        ]
+        scheme_dict = all_scheme_dicts.pop()
+
+        # multiple schemes found
+        if not all([cur_dict == scheme_dict for cur_dict in all_scheme_dicts]):
+            raise ValueError("All MoE projections need to have same "
+                             "quantization scheme but found multiple")
+
+        if scheme_dict is None:
+            return VllmUnquantizedFusedMoEMethod(layer.moe_config,
+                                                 quant_config.mesh)
+
+        weight_quant = scheme_dict.get("weights")
+        input_quant = scheme_dict.get("input_activations")
+
+        if quant_config._is_fp8_w8a8(weight_quant, input_quant):
+            return VllmCompressedTensorsW8A8Fp8MoEMethod(
+                weight_quant, input_quant, layer.moe_config, quant_config.mesh)
+        else:
+            raise RuntimeError(
+                f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
+
+
 class VllmCompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsW8A8Fp8MoEMethod,
-                                            JaxCommonConfig):
+                                            VllmQuantConfig):
+
+    def __init__(
+        self,
+        weight_quant: QuantizationArgs,
+        input_quant: QuantizationArgs,
+        moe: FusedMoEConfig,
+        mesh: Mesh,
+    ):
+        super().__init__(weight_quant, input_quant, moe)
 
-    def __init__(self, quant_config: "CompressedTensorsConfig",
-                 moe: FusedMoEConfig, mesh: Mesh):
-        super().__init__(quant_config, moe)
         self.mesh = mesh
-        self.quant_config = quant_config
+        self.moe_backend = select_moe_backend(self.moe)
 
-        # disable GPU paths
-        self.use_marlin = False
-        self.rocm_aiter_moe_enabled = False  # is_rocm_aiter_moe_enabled()
-        self.is_fp8_w8a8_sm100 = False
-        self.use_cutlass = False
-        self.disable_expert_map = False
+        self.extra_backend_kwargs = {}
+        if self.moe_backend == FusedMoEBackend.FUSED_MOE:
+            raise NotImplementedError(
+                "Per-channel quantization is not supported in FusedMoE kernel."
+            )
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        """
+        Docstring for process_weights_after_loading
+
+        :param self: Description
+        :param layer: Description
+        :type layer: torch.nn.Module
+
+        Steps:
+        1. Read weights from layer object and convert to jax arrays
+        2. Interleave concat w13 weights
+        3. Shard weights for tp (rowwise w13, colwise w2)
+        4. Initialize Params as torch.nn.Parameter
+           a. w13_weight - float8_e4m3fn shape: (num_experts, 2 x intermediate_size, input_size)
+           b. w2_weight - float8_e4m3fn shape: (num_experts, output_size, intermediate_size)
+           c. w13_weight_scale - FP32 shape: (num_experts, 2 x intermediate_size, 1)
+           d. w2_weight_scale - FP32 shape: (num_experts, output_size, 1)
+        """
         assert isinstance(layer, FusedMoE)
 
-        intermediate_size = layer.w13_weight.shape[1] // 2
-        w1_weight = layer.w13_weight[:, :intermediate_size]
-        w3_weight = layer.w13_weight[:, intermediate_size:]
-        w1_weight_scale = layer.w13_weight_scale[:, :intermediate_size]
-        w3_weight_scale = layer.w13_weight_scale[:, intermediate_size:]
-
+        w13_weight = t2j(layer.w13_weight, use_dlpack=False)
+        w13_weight_scale = t2j(layer.w13_weight_scale, use_dlpack=False)
         w2_weight = t2j(layer.w2_weight, use_dlpack=False)
-        w2_weight_scale = t2j(layer.w2_weight_scale.to(torch.bfloat16),
-                              use_dlpack=False)
-        w1_weight = t2j(w1_weight, use_dlpack=False)
-        w1_weight_scale = t2j(w1_weight_scale.to(torch.bfloat16),
-                              use_dlpack=False)
-        w3_weight = t2j(w3_weight, use_dlpack=False)
-        w3_weight_scale = t2j(w3_weight_scale.to(torch.bfloat16),
-                              use_dlpack=False)
-
-        if layer.use_ep:
-            format = Format(Layout((0, 1, 2)),
-                            NamedSharding(self.mesh, P("model", None, None)))
-            w1_weight = jax.device_put(w1_weight, format)
-            w1_weight_scale = jax.device_put(w1_weight_scale, format)
-            w3_weight = jax.device_put(w3_weight, format)
-            w3_weight_scale = jax.device_put(w3_weight_scale, format)
-            w2_weight = jax.device_put(w2_weight, format)
-            w2_weight_scale = jax.device_put(w2_weight_scale, format)
+        w2_weight_scale = t2j(layer.w2_weight_scale, use_dlpack=False)
+
+        if self.moe.has_bias:
+            w13_bias = t2j(layer.w13_bias, use_dlpack=False)
+            w2_bias = t2j(layer.w2_bias, use_dlpack=False)
         else:
-            assert intermediate_size == w2_weight.shape[-1]
-            n_shards = self.mesh.shape["model"]
-            assert intermediate_size % n_shards == 0
-
-            # TODO: enable this if using fused weights
-            # output_sizes = [intermediate_size, intermediate_size]
-            # w13_weight = reorder_concatenated_tensor_for_sharding(
-            #     w13_weight, output_sizes, n_shards, dim=1
-            # )
-
-            w13_format = Format(
-                Layout((0, 1, 2)),
-                NamedSharding(self.mesh, P(None, "model", None)))
-            w1_weight = jax.device_put(w1_weight, w13_format)
-            w1_weight_scale = jax.device_put(w1_weight_scale, w13_format)
-            w3_weight = jax.device_put(w3_weight, w13_format)
-            w3_weight_scale = jax.device_put(w3_weight_scale, w13_format)
-            w2_weight = jax.device_put(
-                w2_weight,
-                Format(Layout((0, 1, 2)),
-                       NamedSharding(self.mesh, P(None, None, "model"))),
+            w13_bias = w2_bias = None
+
+        @jax.jit
+        def process_fp8_moe_weights(
+            w13_weight: jax.Array,
+            w13_weight_scale: jax.Array,
+            w13_bias: jax.Array | None,
+            w2_weight: jax.Array,
+            w2_weight_scale: jax.Array,
+            w2_bias: jax.Array | None,
+        ) -> FusedMoEWeights:
+            w13_interleave = layer.activation == "swigluoai"
+            w13_reorder_size = get_mesh_shape_product(
+                self.mesh, ShardingAxisName.MLP_TENSOR)
+
+            return process_moe_weights(
+                weights=FusedMoEWeights(
+                    w13_weight=w13_weight,
+                    w13_weight_scale=w13_weight_scale,
+                    w13_bias=w13_bias,
+                    w2_weight=w2_weight,
+                    w2_weight_scale=w2_weight_scale,
+                    w2_bias=w2_bias,
+                ),
+                moe_backend=self.moe_backend,
+                w13_reorder_size=w13_reorder_size,
+                w13_interleave=w13_interleave,
             )
-            w2_weight_scale = jax.device_put(
-                w2_weight_scale,
-                Format(Layout((0, 1, 2)), NamedSharding(self.mesh, P())),
-            )  # replicate
-
-        w1_weight = Parameter(torch_view(w1_weight), requires_grad=False)
-        w1_weight_scale = Parameter(torch_view(w1_weight_scale),
-                                    requires_grad=False)
-        w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
-        w2_weight_scale = Parameter(torch_view(w2_weight_scale),
-                                    requires_grad=False)
-        w3_weight = Parameter(torch_view(w3_weight), requires_grad=False)
-        w3_weight_scale = Parameter(torch_view(w3_weight_scale),
-                                    requires_grad=False)
-
-        # TODO dont reuse variable
-        layer.w13_weight = w1_weight
-        layer.w13_weight_scale = w1_weight_scale
-        layer.w2_weight = w2_weight
-        layer.w2_weight_scale = w2_weight_scale
-        layer.w3_weight = w3_weight
-        layer.w3_weight_scale = w3_weight_scale
+
+        weights = process_fp8_moe_weights(
+            w13_weight,
+            w13_weight_scale,
+            w13_bias,
+            w2_weight,
+            w2_weight_scale,
+            w2_bias,
+        )
+        weights = torch_view(
+            shard_moe_weights(weights, self.moe_backend, self.mesh))
+
+        layer.w13_weight = Parameter(weights.w13_weight, requires_grad=False)
+        layer.w2_weight = Parameter(weights.w2_weight, requires_grad=False)
+
+        layer.w13_weight_scale = Parameter(weights.w13_weight_scale,
+                                           requires_grad=False)
+        layer.w2_weight_scale = Parameter(weights.w2_weight_scale,
+                                          requires_grad=False)
+
+        if self.moe.has_bias:
+            layer.w13_bias = Parameter(weights.w13_bias, requires_grad=False)
+            layer.w2_bias = Parameter(weights.w2_bias, requires_grad=False)
 
     def apply(
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
         router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
-        scoring_func: str = "softmax",
-        routed_scaling_factor: float = 1.0,
-        e_score_correction_bias: Optional[torch.Tensor] = None,
-        apply_router_weight_on_input: bool = False,
-        activation: str = "silu",
-        enable_eplb: bool = False,
-        expert_load_view: Optional[torch.Tensor] = None,
-        logical_to_physical_map: Optional[torch.Tensor] = None,
-        logical_replica_count: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
-        assert isinstance(layer, FusedMoE)
-        if activation != "silu":
-            raise NotImplementedError(
-                "Only silu is supported for activation function.")
-        if scoring_func != "softmax":
-            raise NotImplementedError(
-                "Only softmax is supported for scoring_func")
-
-        # import sys
-        # sys.stdin = open(0)
-        # breakpoint()
-
-        # TODO: Use MoE kernel when it supports fp8
-
-        seqlen = x.shape[0]
-
-        expert_weights = F.softmax(router_logits, dim=-1)
-        expert_weights, expert_indices = torch.topk(expert_weights,
-                                                    top_k,
-                                                    dim=-1)
-        if renormalize:
-            expert_weights /= expert_weights.sum(dim=-1, keepdim=True)
-
-        # cond ffn
-        # e = total num of exp = 160
-        # t = seqlen
-        # o = config.imtermediate size
-        # i = config.dim
-        #torch.einsum("ti, eoi -> teo", x, layer.w13_weight) * self.w13_weight_scale)
-        ux1 = call_jax(jax.lax.dot,
-                       x,
-                       layer.w13_weight,
-                       dimension_numbers=(((1, ), (2, )), ((), ())),
-                       preferred_element_type=jnp.bfloat16.dtype)
-        x1 = F.silu(ux1 * layer.w13_weight_scale.squeeze(2))
-
-        #x3 = torch.einsum("ti, eoi -> teo", x, layer.w3_weight) * self.w3_weight_scale
-        x3 = call_jax(jax.lax.dot,
-                      x,
-                      layer.w3_weight,
-                      dimension_numbers=(((1, ), (2, )), ((), ())),
-                      preferred_element_type=jnp.bfloat16.dtype
-                      ) * layer.w3_weight_scale.squeeze(2)
-
-        #expert_outs = torch.einsum("teo, eio -> tei", (x1 * x3), self.w2_weight) * self.w2_weight_scale
-        expert_outs = call_jax(
-            jax.lax.dot,
-            x1 * x3,
-            layer.w2_weight,
-            dimension_numbers=(((2, ), (2, )), ((1, ), (0, ))),
-            preferred_element_type=jnp.bfloat16.dtype).transpose(
-                0, 1) * layer.w2_weight_scale.squeeze(2)
-
-        seq_indexes = torch.arange(seqlen, device='jax').unsqueeze(1)
-        expert_outs = expert_outs[seq_indexes, expert_indices]
-
-        # out = torch.einsum("tai,ta -> ti", expert_outs, expert_weights)
-        out = call_jax(jax.lax.dot,
-                       expert_outs,
-                       expert_weights,
-                       dimension_numbers=(((1, ), (1, )), ((0, ), (0, ))),
-                       preferred_element_type=jnp.bfloat16.dtype)
-
-        return out
+    ) -> torch.Tensor:
+        return fused_moe_apply(
+            layer,
+            x,
+            router_logits,
+            self.moe_backend,
+            self.mesh,
+            self.extra_backend_kwargs,
+        )
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Optional
 
 import jax
@@ -6,49 +20,27 @@ import torch
 from compressed_tensors.quantization import (QuantizationArgs,
                                              QuantizationStrategy)
 from jax.sharding import NamedSharding, PartitionSpec
+from torch.nn.parameter import Parameter
 from torchax.interop import jax_view, torch_view
 from torchax.ops.mappings import t2j
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
     CompressedTensorsW8A8Fp8
-from vllm.model_executor.layers.quantization.utils.w8a8_utils import \
-    per_tensor_dequantize
 
-from tpu_inference.layers.vllm.linear_common import (
-    sharded_quantized_matmul, slice_sharded_tensor_for_concatenation,
-    torch_to_jax_param)
-from tpu_inference.layers.vllm.quantization.common import JaxCommonLinearConfig
+from tpu_inference.layers.common.quantization import (dequantize_tensor,
+                                                       quantize_tensor)
+from tpu_inference.layers.common.utils import \
+    slice_sharded_tensor_for_concatenation
+from tpu_inference.layers.vllm.linear import sharded_quantized_matmul
+from tpu_inference.layers.vllm.process_weights.linear_weights import (
+    LinearWeights, process_lienar_weights, shard_linear_weights,
+    to_parameter_list)
+from tpu_inference.layers.vllm.quantization.configs import \
+    VllmQuantLinearConfig
+from tpu_inference.logger import init_logger
 
 P = PartitionSpec
 
-
-def requantize_with_max_scale(
-        weight: torch.Tensor, weight_scale: torch.Tensor,
-        logical_widths: list[int]) -> tuple[torch.Tensor, torch.Tensor]:
-    dtype = weight.dtype
-    dtype_info = torch.finfo(dtype)
-    maxval = float(dtype_info.max)
-    minval = float(dtype_info.min)
-
-    max_w_scale = weight_scale.max()
-
-    unfused_module_in_checkpoint = (weight_scale[-1]
-                                    > torch.finfo(torch.float8_e4m3fn).min)
-
-    # If unfused checkpoint, need requanize with the single scale.
-    if unfused_module_in_checkpoint:
-        start = 0
-        for idx, logical_width in enumerate(logical_widths):
-            # Skip any component with zero width.
-            if logical_width == 0:
-                continue
-            end = start + logical_width
-            weight_dq = per_tensor_dequantize(weight[start:end, :],
-                                              weight_scale[idx])
-            weight_q = weight_dq / max_w_scale
-            weight[start:end, :] = weight_q.clamp(minval, maxval).to(dtype)
-            start = end
-
-    return max_w_scale, weight
+logger = init_logger(__name__)
 
 
 class VllmCompressedTensorsW8A8Fp8(CompressedTensorsW8A8Fp8):
@@ -57,15 +49,86 @@ class VllmCompressedTensorsW8A8Fp8(CompressedTensorsW8A8Fp8):
         self,
         weight_quant: QuantizationArgs,
         is_static_input_scheme: bool,
-        jax_config: JaxCommonLinearConfig,
+        linear_config: VllmQuantLinearConfig,
     ):
         super().__init__(weight_quant, is_static_input_scheme)
 
-        self.jax_config = jax_config
+        self.linear_config = linear_config
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        weight = layer.weight
-        weight_scale = layer.weight_scale
+        weight = t2j(layer.weight, use_dlpack=False)
+        delattr(layer, "weight")
+        weight_scale = t2j(layer.weight_scale, use_dlpack=False)
+        delattr(layer, "weight_scale")
+
+        if layer.bias is not None and not layer.skip_bias_add:
+            if layer.return_bias:
+                logger.warning_once("Bias might return incorrect value.")
+            bias = t2j(layer.bias, use_dlpack=False)
+            delattr(layer, "bias")
+        else:
+            bias = None
+
+        per_tensor = self.strategy == QuantizationStrategy.TENSOR
+
+        @jax.jit
+        def process_fp8_linear_weights(
+            weight: jax.Array,
+            weight_scale: jax.Array,
+            bias: jax.Array | None,
+        ) -> LinearWeights:
+            if per_tensor:
+                weights = []
+                start = 0
+                # Multiple weights may have been concatenated. Loop through
+                # each weight and perform dequantization.
+                for i, output_size in enumerate(
+                        self.linear_config.output_sizes):
+                    end = start + output_size
+                    weights.append(
+                        dequantize_tensor(weight[start:end], weight_scale[i]))
+                    start = end
+                weight = jnp.concat(weights, axis=0)
+                # Requantize into per-tensor.
+                weight, weight_scale = quantize_tensor(jnp.float8_e4m3fn,
+                                                       weight, None)
+            else:
+                weight_scale = jnp.squeeze(weight_scale, -1)
+
+            return process_lienar_weights(
+                LinearWeights(
+                    weight=weight,
+                    weight_scale=weight_scale,
+                    zero_point=None,
+                    bias=bias,
+                ),
+                fused=self.linear_config.fuse_matmuls,
+                output_sizes=self.linear_config.output_sizes,
+                reorder_size=self.linear_config.n_shards,
+                per_tensor=per_tensor,
+            )
+
+        weights = process_fp8_linear_weights(weight, weight_scale, bias)
+        weights = torch_view(
+            shard_linear_weights(
+                weights,
+                mesh=self.linear_config.mesh,
+                weight_p_spec=self.linear_config.weight_sharding,
+                bias_p_spec=self.linear_config.bias_sharding,
+                per_tensor=per_tensor,
+            ))
+
+        if self.linear_config.fuse_matmuls:
+            layer.weight = Parameter(weights.weight, requires_grad=False)
+            layer.weight_scale = Parameter(weights.weight_scale,
                                           requires_grad=False)
+            if bias is not None:
+                layer.bias = Parameter(weights.bias, requires_grad=False)
+        else:
+            layer.weight = to_parameter_list(weights.weight)
+            layer.weight_scale = to_parameter_list(weights.weight_scale)
+            if bias is not None:
+                layer.bias = to_parameter_list(weights.bias)
 
         if self.is_static_input_scheme:
             # In static quant, all input_scales share the same value.
@@ -74,59 +137,16 @@ class VllmCompressedTensorsW8A8Fp8(CompressedTensorsW8A8Fp8):
 
             input_scale = jax.device_put(
                 t2j(input_scale_first, use_dlpack=False),
-                NamedSharding(self.jax_config.mesh, P()))
+                NamedSharding(self.linear_config.mesh, P()))
             input_scale = torch.nn.Parameter(torch_view(input_scale),
                                              requires_grad=False)
            delattr(layer, "input_scale")
            layer.input_scale = input_scale
 
-        # TODO(kyuyeunk): Investigate performance gain from merging scales.
-        # By merging input and weight scales, we reduce the number of muls
-        # required for dequantization from 2 (for each scales) to 1.
-        # weight_scale *= input_scale_first
-
-        if self.strategy == QuantizationStrategy.TENSOR:
-            weight_scale, weight = requantize_with_max_scale(
-                weight, weight_scale, self.jax_config.output_sizes)
-            weight_scale = jax.device_put(
-                t2j(weight_scale, use_dlpack=False),
-                NamedSharding(self.jax_config.mesh, P()))
-            weight_scale = torch.nn.Parameter(torch_view(weight_scale),
-                                              requires_grad=False)
-        else:
-            weight_scale = weight_scale.squeeze(-1)
-            weight_scale = torch_to_jax_param(
-                weight_scale,
-                NamedSharding(self.jax_config.mesh,
-                              self.jax_config.bias_sharding),
-                self.jax_config.output_sizes, self.jax_config.n_shards,
-                self.jax_config.fuse_matmuls)
-        delattr(layer, "weight_scale")
-        layer.weight_scale = weight_scale
-
-        weight = torch_to_jax_param(
-            layer.weight,
-            NamedSharding(self.jax_config.mesh,
-                          self.jax_config.weight_sharding),
-            self.jax_config.output_sizes, self.jax_config.n_shards,
-            self.jax_config.fuse_matmuls)
-        delattr(layer, "weight")
-        layer.weight = weight
-
-        if layer.bias is not None:
-            bias = torch_to_jax_param(
-                layer.bias,
-                NamedSharding(self.jax_config.mesh,
-                              self.jax_config.bias_sharding),
-                self.jax_config.output_sizes, self.jax_config.n_shards,
-                self.jax_config.fuse_matmuls)
-            delattr(layer, "bias")
-            layer.bias = bias
-
     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                       bias: Optional[torch.Tensor]) -> torch.Tensor:
         with jax.named_scope(layer._get_name()):
-            if self.jax_config.fuse_matmuls:
+            if self.linear_config.fuse_matmuls:
                 return self._apply_fused(layer, x, bias)
             else:
                 return self._apply_split(layer, x, bias)
@@ -157,13 +177,13 @@ class VllmCompressedTensorsW8A8Fp8(CompressedTensorsW8A8Fp8):
        else:
            outs = sharded_quantized_matmul(x_jax, weight_jax,
                                            weight_scale_jax,
-                                           self.jax_config.mesh,
-                                           self.jax_config.weight_sharding)
+                                           self.linear_config.mesh,
+                                           self.linear_config.weight_sharding)
 
            if bias is not None and not layer.skip_bias_add:
                outs += jax_view(bias)
            outs = slice_sharded_tensor_for_concatenation(
-               outs, self.jax_config.output_sizes, self.jax_config.n_shards)
+               outs, self.linear_config.output_sizes, self.linear_config.n_shards)
            return torch_view(jnp.concatenate(outs, axis=-1))
 
     def _apply_split(self, layer: torch.nn.Module, x: torch.Tensor,
@@ -197,10 +217,10 @@ class VllmCompressedTensorsW8A8Fp8(CompressedTensorsW8A8Fp8):
                out *= weight_scale_jax * input_scale
                out = out.astype(x_jax.dtype)
            else:
-               out = sharded_quantized_matmul(x_jax, weight_jax,
-                                              weight_scale_jax,
-                                              self.jax_config.mesh,
-                                              self.jax_config.weight_sharding)
+               out = sharded_quantized_matmul(
+                   x_jax, weight_jax, weight_scale_jax,
+                   self.linear_config.mesh,
+                   self.linear_config.weight_sharding)
 
            if bias is not None and not layer.skip_bias_add:
                out += jax_view(bias[i])