tpu-inference 0.11.1.dev202511150811__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of tpu-inference might be problematic.

Files changed (179)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +105 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +654 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +96 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +182 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +236 -0
  27. tpu_inference/__init__.py +34 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +51 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +728 -0
  37. tpu_inference/distributed/utils.py +59 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +107 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +362 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +390 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +507 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +105 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +311 -0
  120. tpu_inference/mock/__init__.py +0 -0
  121. tpu_inference/mock/vllm_config_utils.py +28 -0
  122. tpu_inference/mock/vllm_envs.py +1219 -0
  123. tpu_inference/mock/vllm_logger.py +212 -0
  124. tpu_inference/mock/vllm_logging_utils.py +15 -0
  125. tpu_inference/models/__init__.py +0 -0
  126. tpu_inference/models/common/__init__.py +0 -0
  127. tpu_inference/models/common/model_loader.py +444 -0
  128. tpu_inference/models/jax/__init__.py +0 -0
  129. tpu_inference/models/jax/deepseek_v3.py +868 -0
  130. tpu_inference/models/jax/gpt_oss.py +492 -0
  131. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  132. tpu_inference/models/jax/llama3.py +375 -0
  133. tpu_inference/models/jax/llama4.py +629 -0
  134. tpu_inference/models/jax/llama_eagle3.py +333 -0
  135. tpu_inference/models/jax/phi3.py +376 -0
  136. tpu_inference/models/jax/qwen2.py +375 -0
  137. tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
  138. tpu_inference/models/jax/qwen3.py +302 -0
  139. tpu_inference/models/jax/utils/__init__.py +0 -0
  140. tpu_inference/models/jax/utils/file_utils.py +96 -0
  141. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  142. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  143. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  144. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  145. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  146. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  147. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  148. tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
  149. tpu_inference/models/jax/utils/weight_utils.py +529 -0
  150. tpu_inference/models/vllm/__init__.py +0 -0
  151. tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
  152. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  153. tpu_inference/platforms/__init__.py +2 -0
  154. tpu_inference/platforms/tpu_platform.py +269 -0
  155. tpu_inference/runner/__init__.py +0 -0
  156. tpu_inference/runner/block_table.py +122 -0
  157. tpu_inference/runner/compilation_manager.py +780 -0
  158. tpu_inference/runner/input_batch.py +435 -0
  159. tpu_inference/runner/kv_cache.py +132 -0
  160. tpu_inference/runner/kv_cache_manager.py +479 -0
  161. tpu_inference/runner/lora_utils.py +92 -0
  162. tpu_inference/runner/multimodal_manager.py +217 -0
  163. tpu_inference/runner/persistent_batch_manager.py +244 -0
  164. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  165. tpu_inference/runner/structured_decoding_manager.py +88 -0
  166. tpu_inference/runner/tpu_runner.py +1620 -0
  167. tpu_inference/runner/utils.py +426 -0
  168. tpu_inference/spec_decode/__init__.py +0 -0
  169. tpu_inference/spec_decode/jax/__init__.py +0 -0
  170. tpu_inference/spec_decode/jax/eagle3.py +367 -0
  171. tpu_inference/tpu_info.py +77 -0
  172. tpu_inference/utils.py +317 -0
  173. tpu_inference/worker/__init__.py +0 -0
  174. tpu_inference/worker/tpu_worker.py +321 -0
  175. tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
  176. tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
  177. tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
  178. tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
  179. tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tpu_inference/layers/vllm/quantization/unquantized.py
@@ -0,0 +1,386 @@
from typing import Any, Callable, Optional, Union

import jax
import jax.numpy as jnp
import torch
from jax.experimental.layout import Format, Layout
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from torch.nn.parameter import Parameter
from torchax.interop import jax_view, torch_view
from torchax.ops.mappings import t2j
from vllm.attention.layer import Attention
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE, FusedMoEConfig, UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize)
from vllm.model_executor.layers.linear import (LinearBase,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import \
    register_quantization_config
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)

from tpu_inference import envs
from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
from tpu_inference.layers.common.quant_methods import (UNQUANTIZED,
                                                       get_tpu_quant_method)
from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded
from tpu_inference.layers.vllm.linear_common import (
    reorder_concatenated_tensor_for_sharding,
    slice_sharded_tensor_for_concatenation, torch_to_jax_param)
from tpu_inference.layers.vllm.quantization.common import (
    JaxCommonConfig, JaxCommonLinearConfig)

P = PartitionSpec
logger = init_logger(__name__)


@register_quantization_config(get_tpu_quant_method(UNQUANTIZED))
class VllmUnquantizedConfig(QuantizationConfig, JaxCommonConfig):

    @classmethod
    def get_name(cls) -> str:
        return UNQUANTIZED

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.float32, torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 0  # Always supported

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []  # No extra configs required.

    @classmethod
    def from_config(cls, _: dict[str, Any]) -> "VllmUnquantizedConfig":
        return cls()

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional[QuantizeMethodBase]:
        if isinstance(layer, LinearBase):
            linear_config = self.get_linear_config(layer)
            return VllmUnquantizedLinearMethod(linear_config)
        if isinstance(layer, FusedMoE):
            moe_config = self.get_moe_config(layer)
            return VllmUnquantizedFusedMoEMethod(moe_config, self.mesh)
        if isinstance(layer, Attention):
            return None
        return None


class VllmUnquantizedLinearMethod(UnquantizedLinearMethod):

    def __init__(self, jax_config: JaxCommonLinearConfig):
        self.jax_config = jax_config

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        weight = torch_to_jax_param(
            layer.weight,
            NamedSharding(self.jax_config.mesh,
                          self.jax_config.weight_sharding),
            self.jax_config.output_sizes,
            self.jax_config.n_shards,
            self.jax_config.fuse_matmuls,
        )
        delattr(layer, "weight")
        layer.weight = weight

        if layer.bias is not None and not layer.skip_bias_add:
            if layer.return_bias:
                logger.warning_once("Bias might return incorrect value.")

            bias = torch_to_jax_param(
                layer.bias,
                NamedSharding(self.jax_config.mesh,
                              self.jax_config.bias_sharding),
                self.jax_config.output_sizes,
                self.jax_config.n_shards,
                self.jax_config.fuse_matmuls,
            )
            delattr(layer, "bias")
            layer.bias = bias

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        with jax.named_scope(layer._get_name()):
            if in_sharding := self.jax_config.get_input_sharding(x):
                x.shard_(NamedSharding(self.jax_config.mesh, in_sharding))

            if self.jax_config.fuse_matmuls:
                out = self._apply_fused(layer, x, bias)
            else:
                out = self._apply_split(layer, x, bias)

            if out_sharding := self.jax_config.get_output_sharding(out):
                out.shard_(NamedSharding(self.jax_config.mesh, out_sharding))

            return out

    def _apply_fused(self,
                     layer: torch.nn.Module,
                     x: torch.Tensor,
                     bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        x_jax = jax_view(x)
        weight_jax = jax_view(layer.weight)

        outs = jnp.einsum("mn,pn->mp", x_jax, weight_jax)
        if bias is not None and not layer.skip_bias_add:
            outs += bias.jax()

        outs = slice_sharded_tensor_for_concatenation(
            outs, self.jax_config.output_sizes, self.jax_config.n_shards)
        out = jnp.concatenate(outs, axis=-1)
        return torch_view(out)

    def _apply_split(self,
                     layer: torch.nn.Module,
                     x: torch.Tensor,
                     bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        assert isinstance(layer.weight, torch.nn.ParameterList)

        x_jax = x.jax()
        outs = []
        for i, weight in enumerate(layer.weight):
            weight_jax = jax_view(weight)

            out = jnp.einsum("mn,pn->mp", x_jax, weight_jax)
            if bias is not None and not layer.skip_bias_add:
                out += jax_view(bias[i])

            outs.append(out)
        out = jnp.concatenate(outs, axis=-1)
        return torch_view(out)
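
Aside (not part of the package diff): the two apply paths above compute y = x @ W^T via jnp.einsum("mn,pn->mp", x, W); the split path just runs one einsum per output block and concatenates the partial results along the last axis. A minimal, self-contained sketch of that equivalence in plain JAX (the shapes here are invented for illustration):

# Illustrative sketch only; not code from tpu-inference.
import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(0)
x = jnp.asarray(rng.normal(size=(4, 16)))    # (tokens, in_features)
w = jnp.asarray(rng.normal(size=(24, 16)))   # (out_features, in_features)

# Fused path: one einsum over the full weight, i.e. y = x @ w.T.
y_fused = jnp.einsum("mn,pn->mp", x, w)

# Split path: one einsum per output block, then concatenate columns.
w_blocks = jnp.split(w, [8], axis=0)         # e.g. two concatenated projections
y_split = jnp.concatenate(
    [jnp.einsum("mn,pn->mp", x, wb) for wb in w_blocks], axis=-1)

assert jnp.allclose(y_fused, y_split, atol=1e-5)
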
class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):

    def __init__(self,
                 moe: FusedMoEConfig,
                 mesh: Mesh,
                 ep_axis_name: str = 'model'):
        super().__init__(moe)
        self.mesh = mesh
        self.use_kernel = envs.USE_MOE_EP_KERNEL
        self.ep_axis_name = ep_axis_name
        # TODO: Use autotune table once we have it.
        self.block_size = {
            "bt": 16,
            "bf": 384,
            "bd1": 512,
            "bd2": 512,
            "btc": 16,
            "bfc": 384,
            "bd1c": 256,
            "bd2c": 256,
        }

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        moe: FusedMoEConfig,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        raise NotImplementedError(
            "Selecting gemm implementation is currently not supported.")

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        assert isinstance(layer, FusedMoE)
        available_devices = self.mesh.devices.flatten()
        with jax.default_device(available_devices[0]):
            w13_weight = t2j(layer.w13_weight, use_dlpack=False)
            w2_weight = t2j(layer.w2_weight, use_dlpack=False)

            if self.moe.has_bias:
                w13_bias = t2j(layer.w13_bias, use_dlpack=False)
                w2_bias = t2j(layer.w2_bias, use_dlpack=False)

        if layer.activation == "swigluoai":
            # When using swigluoai, vLLM splits the gmm output in an
            # interleaved way. However, an interleaved split is not performant
            # on TPU, so we preprocess the weight so that splitting the gmm
            # output down the middle yields the same result.
            w1_weight = w13_weight[:, ::2, :]
            w3_weight = w13_weight[:, 1::2, :]
            w13_weight = jnp.concat([w1_weight, w3_weight], axis=1)

            if self.moe.has_bias:
                w1_bias = w13_bias[:, ::2]
                w3_bias = w13_bias[:, 1::2]
                w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)

        if self.use_kernel and layer.use_ep:
            # Kernel expects:
            #   w13: (num_experts, 2, hidden_size, intermediate_size)
            #   w2: (num_experts, intermediate_size, hidden_size)
            # Current format:
            #   w13_weight: (num_experts, 2*intermediate_size, hidden_size)
            #   w2_weight: (num_experts, hidden_size, intermediate_size)
            num_experts = w13_weight.shape[0]
            intermediate_size = w13_weight.shape[1] // 2
            hidden_size = w13_weight.shape[2]

            # Reshape and transpose w13_weight to
            # (num_experts, 2, hidden_size, intermediate_size).
            w13_reshaped = w13_weight.reshape(num_experts, 2,
                                              intermediate_size, hidden_size)
            w13_weight_transposed = jnp.transpose(w13_reshaped, (0, 1, 3, 2))

            # Transpose w2_weight to (num_experts, intermediate_size, hidden_size).
            w2_weight_transposed = jnp.transpose(w2_weight, (0, 2, 1))

            # Apply EP sharding.
            w13_weight = jax.device_put(
                w13_weight_transposed,
                Format(
                    Layout((0, 1, 2, 3)),
                    NamedSharding(self.mesh, P("model", None, None, None))))
            w2_weight = jax.device_put(
                w2_weight_transposed,
                Format(Layout((0, 1, 2)),
                       NamedSharding(self.mesh, P("model", None, None))))

            if self.moe.has_bias:
                w13_bias = w13_bias.reshape(num_experts, 2, intermediate_size)

                # Apply EP sharding.
                w13_bias = jax.device_put(
                    w13_bias,
                    Format(
                        Layout((0, 1, 2)),
                        NamedSharding(self.mesh, P("model", None, None))))
                w2_bias = jax.device_put(
                    w2_bias,
                    Format(Layout((0, 1)),
                           NamedSharding(self.mesh, P("model", None))))

        else:
            # Original logic for the non-kernel path.
            if layer.use_ep:
                w13_weight = jax.device_put(
                    w13_weight,
                    Format(
                        Layout((0, 1, 2)),
                        NamedSharding(self.mesh, P("model", None, None))))
                w2_weight = jax.device_put(
                    w2_weight,
                    Format(
                        Layout((0, 1, 2)),
                        NamedSharding(self.mesh, P("model", None, None))))

                if self.moe.has_bias:
                    w13_bias = jax.device_put(
                        w13_bias,
                        Format(Layout((0, 1)),
                               NamedSharding(self.mesh, P("model", None))))
                    w2_bias = jax.device_put(
                        w2_bias,
                        Format(Layout((0, 1)),
                               NamedSharding(self.mesh, P("model", None))))

            else:
                intermediate_size = w13_weight.shape[1] // 2
                assert intermediate_size == w2_weight.shape[-1]
                output_sizes = [intermediate_size, intermediate_size]
                n_shards = self.mesh.shape["model"]
                assert intermediate_size % n_shards == 0
                w13_weight = reorder_concatenated_tensor_for_sharding(
                    w13_weight, output_sizes, n_shards, dim=1)
                w13_weight = jax.device_put(
                    w13_weight,
                    Format(
                        Layout((0, 1, 2)),
                        NamedSharding(self.mesh, P(None, "model", None))))
                w2_weight = jax.device_put(
                    w2_weight,
                    Format(
                        Layout((0, 1, 2)),
                        NamedSharding(self.mesh, P(None, None, "model"))))

                if self.moe.has_bias:
                    w13_bias = reorder_concatenated_tensor_for_sharding(
                        w13_bias, output_sizes, n_shards, dim=1)
                    w13_bias = jax.device_put(
                        w13_bias,
                        Format(Layout((0, 1)),
                               NamedSharding(self.mesh, P(None, "model"))))
                    w2_bias = jax.device_put(
                        w2_bias,
                        Format(Layout((0, 1)),
                               NamedSharding(self.mesh, P(None, None))))

        layer.w13_weight = Parameter(torch_view(w13_weight),
                                     requires_grad=False)
        layer.w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)

        if self.moe.has_bias:
            layer.w13_bias = Parameter(torch_view(w13_bias),
                                       requires_grad=False)
            layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        assert isinstance(layer, FusedMoE)
        if scoring_func != "softmax":
            raise NotImplementedError(
                "Only softmax is supported for scoring_func")

        if self.use_kernel and layer.use_ep:
            output = fused_ep_moe(
                mesh=self.mesh,
                tokens=jax_view(x),
                w1=jax_view(layer.w13_weight),
                w2=jax_view(layer.w2_weight),
                gating_output=jax_view(router_logits),
                top_k=top_k,
                ep_axis_name=self.ep_axis_name,
                **self.block_size,
            )
        else:
            # Use the original implementation.
            output = fused_moe_func_padded(
                jax_view(x),
                jax_view(layer.w13_weight),
                jax_view(layer.w2_weight),
                jax_view(layer.w13_bias) if self.moe.has_bias else None,
                jax_view(layer.w2_bias) if self.moe.has_bias else None,
                jax_view(router_logits),
                topk=top_k,
                global_num_experts=global_num_experts,
                renormalize=renormalize,
                reduce_results=layer.reduce_results,
                mesh=self.mesh,
                use_ep=layer.use_ep,
                activation=activation,
            )

        return torch_view(output)
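
Aside (not part of the package diff): the swigluoai branch in process_weights_after_loading reorders the rows of w13 so that a contiguous middle split of the grouped-matmul output matches vLLM's interleaved split. A small standalone JAX sketch of why that reordering preserves results (the 2-D shapes stand in for one expert's weight):

# Illustrative sketch only; not code from tpu-inference.
import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(0)
x = jnp.asarray(rng.normal(size=(4, 8)))     # (tokens, hidden)
w13 = jnp.asarray(rng.normal(size=(6, 8)))   # (2*intermediate, hidden), intermediate=3

# Reference: interleaved split of the original output (what swigluoai expects).
y = jnp.einsum("mn,pn->mp", x, w13)
w1_out_ref, w3_out_ref = y[:, ::2], y[:, 1::2]

# TPU-friendly path: reorder the weight rows once, then split down the middle.
w13_reordered = jnp.concatenate([w13[::2, :], w13[1::2, :]], axis=0)
y2 = jnp.einsum("mn,pn->mp", x, w13_reordered)
w1_out, w3_out = jnp.split(y2, 2, axis=-1)

assert jnp.allclose(w1_out, w1_out_ref) and jnp.allclose(w3_out, w3_out_ref)
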
tpu_inference/layers/vllm/sharding.py
@@ -0,0 +1,230 @@
import os

import jax
import jax.numpy as jnp
import torch
import torchax
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from torch.nn import Parameter
from torch.utils import _pytree as pytree
from torchax.interop import jax_view, torch_view
from torchax.ops.mappings import t2j
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
                              MergedColumnParallelLinearWithLoRA,
                              MergedQKVParallelLinearWithLoRA,
                              QKVParallelLinearWithLoRA,
                              ReplicatedLinearWithLoRA,
                              RowParallelLinearWithLoRA)
from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)

from tpu_inference.logger import init_logger

P = PartitionSpec

logger = init_logger(__name__)

TORCH_TO_JAX_DTYPE_MAP = {
    torch.float32: jnp.float32,
    torch.float16: jnp.float16,
    torch.bfloat16: jnp.bfloat16,
}


def shard_model_to_tpu(model: torch.nn.Module,
                       mesh: Mesh) -> dict[str, torchax.torch.Tensor]:
    """
    Shard the model weights and move them to TPU.

    At the same time, turn the weight tensors into torchax tensors so that
    JAX code can interop with them and the overall program can be traced and
    compiled in XLA.

    Args:
        model: A PyTorch model whose weights are in CPU main memory.
        mesh: JAX mesh object for sharding.

    Returns:
        Dictionary of parameters and buffers that will be used as arguments of
        torch.func.functional_call.
    """

    with jax.default_device(jax.devices("cpu")[0]):
        _shard_module_to_tpu(model, mesh)

        params, buffers = _extract_all_params_buffers(model)

        # Replicate the remaining weight tensors on all the TPU chips.
        params, buffers = pytree.tree_map_only(
            _tensor_is_in_cpu,
            lambda tensor: _shard_tensor_to_tpu_replicated(tensor, mesh),
            (params, buffers))

        return {**params, **buffers}


def update_lora(model: torch.nn.Module,
                initial_params_buffers) -> dict[str, torchax.torch.Tensor]:
    params, buffers = _extract_all_params_buffers(model)
    params_buffers = {**params, **buffers}
    for k, v in params_buffers.items():
        if 'lora_a_stacked' in k or 'lora_b_stacked' in k:
            assert k in initial_params_buffers, f"{k} not in initial_params_buffers"
            initial_params_buffers[k] = v

    return initial_params_buffers


def _extract_all_params_buffers(model: torch.nn.Module):
    return dict(model.named_parameters()), dict(model.named_buffers())


def _tensor_is_in_cpu(tensor: torch.Tensor) -> bool:
    # Check whether the tensor has not yet been converted to a torchax tensor.
    if not isinstance(tensor, torchax.tensor.Tensor):
        return True
    # Check whether the torchax tensor is still on CPU.
    return tensor.jax_device == jax.devices('cpu')[0]


def _convert_to_torchax_and_shard(tensor: torch.Tensor,
                                  sharding: NamedSharding) -> torch.Tensor:
    if os.getenv("VLLM_TPU_USING_PATHWAYS", False) and isinstance(
            tensor, torch.Tensor):
        np_tensor = tensor.detach().cpu().to(torch.float32).numpy()
        dtype = TORCH_TO_JAX_DTYPE_MAP.get(tensor.dtype, jnp.float32)
        return torch_view(jax.device_put(np_tensor, sharding).astype(dtype))
    else:
        if isinstance(tensor, torchax.tensor.Tensor):
            tensor = jax_view(tensor)
        else:
            tensor = t2j(tensor)
        return torch_view(_sharded_device_put(tensor, sharding))


def _shard_tensor_to_tpu_replicated(tensor: torch.Tensor,
                                    mesh: Mesh) -> torchax.tensor.Tensor:
    return _convert_to_torchax_and_shard(tensor, NamedSharding(mesh, P()))


def _shard_vocab_parallel_embedding(layer: VocabParallelEmbedding,
                                    mesh: Mesh) -> None:
    weight = _convert_to_torchax_and_shard(
        layer.weight, NamedSharding(mesh, P('model', None)))
    layer.weight = Parameter(weight, requires_grad=False)


def _shard_lm_head(layer: ParallelLMHead, mesh: Mesh):
    # TODO(qihqi): this does not yet handle the case of tie_word_weights=True.
    # If that config is set, we should not create new weights but reuse the
    # weight from VocabParallelEmbedding.
    weight = _convert_to_torchax_and_shard(
        layer.weight, NamedSharding(mesh, P('model', None)))
    layer.weight = Parameter(weight, requires_grad=False)
    if layer.bias is not None:
        bias = _convert_to_torchax_and_shard(layer.bias,
                                             NamedSharding(mesh, P('model')))
        layer.bias = Parameter(bias, requires_grad=False)


def _shard_base_linear_lora_replicated(layer: BaseLinearLayerWithLoRA,
                                       mesh: Mesh) -> None:
    # NOTE: lora_a_stacked[i] has shape [max_loras, 1, num_out, num_in]
    sharded_lora_a_tpu = torch.nn.ParameterList()
    sharded_lora_b_tpu = torch.nn.ParameterList()

    for i in range(layer.n_slices):
        sharded_lora_a_tpu.append(
            _shard_tensor_to_tpu_replicated(layer.lora_a_stacked[i], mesh))
        sharded_lora_b_tpu.append(
            _shard_tensor_to_tpu_replicated(layer.lora_b_stacked[i], mesh))

    layer.lora_a_stacked = sharded_lora_a_tpu
    layer.lora_b_stacked = sharded_lora_b_tpu


def _shard_column_linear_lora(layer: ColumnParallelLinearWithLoRA,
                              mesh: Mesh) -> None:
    assert layer.n_slices > 0, "layer.n_slices should be greater than 0"
    # lora_a_stacked[i] has shape [max_loras, 1, max_lora_rank, in_features]
    sharded_lora_a_tpu = torch.nn.ParameterList()
    sharded_lora_b_tpu = torch.nn.ParameterList()

    # lora_b_stacked[i] has shape [max_loras, 1, out_features, max_lora_rank]
    lora_b_partition_spec = P(None, None, 'model', None)
    lora_b_sharding = NamedSharding(mesh, lora_b_partition_spec)
    for i in range(layer.n_slices):
        sharded_lora_a_tpu.append(
            _shard_tensor_to_tpu_replicated(layer.lora_a_stacked[i], mesh))

        sharded_lora_b_tpu.append(
            _convert_to_torchax_and_shard(layer.lora_b_stacked[i],
                                          lora_b_sharding))

    layer.lora_a_stacked = sharded_lora_a_tpu
    layer.lora_b_stacked = sharded_lora_b_tpu


def _shard_qkv_linear_lora(layer: ColumnParallelLinearWithLoRA,
                           mesh: Mesh) -> None:
    _shard_column_linear_lora(layer, mesh)


def _shard_merged_column_parallel_linear_lora(
        layer: MergedColumnParallelLinearWithLoRA, mesh: Mesh) -> None:
    _shard_column_linear_lora(layer, mesh)


def _shard_merged_qkv_parallel_linear_lora(
        layer: MergedQKVParallelLinearWithLoRA, mesh: Mesh) -> None:
    _shard_column_linear_lora(layer, mesh)


def _shard_row_parallel_linear_lora(layer: RowParallelLinearWithLoRA,
                                    mesh: Mesh) -> None:
    _shard_base_linear_lora_replicated(layer, mesh)


# NOTE: Ordering is important; the first matching type for a given module wins.
MODULE_TYPE_TO_SHARDING_FUNC = [
    # Shard embedding layers
    (ParallelLMHead, _shard_lm_head),
    (VocabParallelEmbedding, _shard_vocab_parallel_embedding),
    # Shard LoRA layers
    (ColumnParallelLinearWithLoRA, _shard_column_linear_lora),
    (QKVParallelLinearWithLoRA, _shard_qkv_linear_lora),
    (MergedColumnParallelLinearWithLoRA,
     _shard_merged_column_parallel_linear_lora),
    (MergedQKVParallelLinearWithLoRA, _shard_merged_qkv_parallel_linear_lora),
    (RowParallelLinearWithLoRA, _shard_row_parallel_linear_lora),
    (ReplicatedLinearWithLoRA, _shard_base_linear_lora_replicated),
]


def _shard_module_to_tpu(model: torch.nn.Module, mesh: Mesh) -> None:
    for path, module in model.named_modules():
        for module_type, sharding_func in MODULE_TYPE_TO_SHARDING_FUNC:
            if type(module) is module_type:
                logger.debug("shard %s with %s", path, sharding_func)
                sharding_func(module, mesh)
                break


def _sharded_device_put(tensor: jax.Array, sharding) -> jax.Array:
    if isinstance(tensor, tuple):
        return tuple(_sharded_device_put(t, sharding) for t in tensor)

    multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
    if multihost_backend != "ray":
        return jax.device_put(tensor, sharding)

    # NOTE: here num_global_devices != num_local_devices, meaning we are in a
    # multi-host setup. Each host runs the same process, and each process only
    # needs to handle the devices accessible to this host.
    shape = tensor.shape
    x_split = [
        jax.device_put(tensor[i], device) for device, i in
        sharding.addressable_devices_indices_map(shape).items()
    ]
    return jax.make_array_from_single_device_arrays(shape,
                                                    sharding,
                                                    x_split,
                                                    dtype=tensor.dtype)
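
Aside (not part of the package diff): the helpers above boil down to two placements on a one-axis JAX mesh named "model": full replication via an empty PartitionSpec, and sharding one dimension of a weight along the mesh axis. A minimal sketch of those two placements with plain JAX (no torchax or vLLM; the shapes are invented for illustration):

# Illustrative sketch only; not code from tpu-inference.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n_dev = jax.device_count()
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

bias = jnp.ones((128,))
embedding = jnp.ones((16 * n_dev, 256))  # (vocab, hidden); vocab divisible by mesh size

# Replicated on every device, as in _shard_tensor_to_tpu_replicated.
bias_replicated = jax.device_put(bias, NamedSharding(mesh, P()))

# Vocab dimension split across the "model" axis, as in _shard_vocab_parallel_embedding.
embedding_sharded = jax.device_put(
    embedding, NamedSharding(mesh, P("model", None)))

print(bias_replicated.sharding)
print(embedding_sharded.sharding)
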
tpu_inference/logger.py
@@ -0,0 +1,10 @@
# SPDX-License-Identifier: Apache-2.0

from vllm.logger import _VllmLogger
from vllm.logger import init_logger as init_vllm_logger


def init_logger(name: str) -> _VllmLogger:
    # Prepend the root "vllm" to the module path to use vllm's configured logger.
    patched_name = "vllm." + name
    return init_vllm_logger(patched_name)
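
Aside (not part of the package diff): the "vllm." prefix works because Python loggers are hierarchical, so a logger named "vllm.<module>" inherits the handlers and level that vLLM configures on its root "vllm" logger. A standard-library-only illustration of that behavior:

# Illustrative sketch only; not code from tpu-inference.
import logging

logging.getLogger("vllm").addHandler(logging.StreamHandler())
logging.getLogger("vllm").setLevel(logging.INFO)

child = logging.getLogger("vllm." + "tpu_inference.example")
child.info("emitted through the handlers configured on the 'vllm' logger")
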