tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511180814__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/kernels/fused_moe_v1_test.py +34 -303
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
- tests/lora/test_layers.py +6 -0
- tests/lora/utils.py +8 -0
- tests/test_envs.py +11 -32
- tests/test_utils.py +2 -1
- tpu_inference/__init__.py +3 -22
- tpu_inference/core/disagg_utils.py +8 -6
- tpu_inference/distributed/tpu_connector.py +4 -3
- tpu_inference/distributed/utils.py +2 -3
- tpu_inference/envs.py +8 -61
- tpu_inference/executors/ray_distributed_executor.py +2 -9
- tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +145 -266
- tpu_inference/layers/common/attention_interface.py +1 -7
- tpu_inference/layers/common/sharding.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +208 -170
- tpu_inference/layers/vllm/quantization/common.py +1 -6
- tpu_inference/layers/vllm/quantization/mxfp4.py +73 -138
- tpu_inference/layers/vllm/quantization/unquantized.py +64 -58
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +2 -1
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/common/model_loader.py +10 -43
- tpu_inference/models/jax/llama3.py +1 -2
- tpu_inference/models/jax/llama_eagle3.py +5 -8
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +1 -2
- tpu_inference/models/jax/qwen2_5_vl.py +48 -163
- tpu_inference/models/jax/qwen3.py +1 -2
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
- tpu_inference/models/jax/utils/weight_utils.py +143 -198
- tpu_inference/models/vllm/vllm_model_wrapper.py +8 -14
- tpu_inference/platforms/tpu_platform.py +31 -37
- tpu_inference/runner/compilation_manager.py +58 -141
- tpu_inference/runner/kv_cache.py +1 -1
- tpu_inference/runner/kv_cache_manager.py +18 -17
- tpu_inference/runner/persistent_batch_manager.py +2 -40
- tpu_inference/runner/structured_decoding_manager.py +3 -2
- tpu_inference/runner/tpu_runner.py +147 -271
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +21 -71
- tpu_inference/tpu_info.py +3 -4
- tpu_inference/utils.py +13 -36
- tpu_inference/worker/tpu_worker.py +25 -162
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/METADATA +3 -4
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/RECORD +55 -50
- tpu_inference/models/jax/llama_guard_4.py +0 -361
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/WHEEL +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/top_level.txt +0 -0
tpu_inference/layers/vllm/quantization/mxfp4.py

@@ -24,11 +24,9 @@ from vllm.model_executor.layers.quantization.mxfp4 import (Mxfp4Backend,
  from vllm.model_executor.layers.quantization.utils.quant_utils import \
      is_layer_skipped

- from tpu_inference import envs
- from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
  from tpu_inference.layers.common.quant_methods import (MXFP4,
                                                          get_tpu_quant_method)
- from tpu_inference.layers.vllm.fused_moe import
+ from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded
  from tpu_inference.layers.vllm.linear_common import \
      reorder_concatenated_tensor_for_sharding
  from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig

@@ -87,14 +85,17 @@ class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
                      fused_mapping=self.packed_modules_mapping,
              ):
                  return VllmUnquantizedLinearMethod(linear_config)
+             # TODO: Add support for MXFP4 Linear Method.
+             # MXFP4 LinearMethod is available in AMD-Quark, refer to that
+             # implementation if you are interested in enabling MXFP4 here.
              logger.warning_once(
                  "MXFP4 linear layer is not implemented - falling back to "
                  "UnquantizedLinearMethod.")
              return VllmUnquantizedLinearMethod(linear_config)
          elif isinstance(layer, FusedMoE):
-             moe_config
-             return VllmMxfp4MoEMethod(moe_config, self.mesh)
+             return VllmMxfp4MoEMethod(layer.moe_config, self.mesh)
          elif isinstance(layer, Attention):
+             # TODO: Add support for MXFP4 Attention.
              logger.warning_once("MXFP4 attention layer is not implemented. "
                                  "Skipping quantization for this layer.")
              return None

@@ -102,30 +103,13 @@ class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):

  class VllmMxfp4MoEMethod(Mxfp4MoEMethod):

-     def __init__(self,
-                  moe: FusedMoEConfig,
-                  mesh: Mesh,
-                  ep_axis_name: str = 'model'):
+     def __init__(self, moe: FusedMoEConfig, mesh: Mesh):
          FusedMoEMethodBase.__init__(self, moe)

          # We piggyback on triton implementation as it applies minimal hardware
          # specific post processing to the weights.
          self.mxfp4_backend = Mxfp4Backend.TRITON
-
          self.mesh = mesh
-         self.use_kernel = envs.USE_MOE_EP_KERNEL and moe.use_ep
-         self.ep_axis_name = ep_axis_name
-         # TODO: Use autotune table once we have it.
-         self.block_size = {
-             "bt": 64,
-             "bf": 1024,
-             "bd1": 1536,
-             "bd2": 1536,
-             "btc": 64,
-             "bfc": 1024,
-             "bd1c": 1536,
-             "bd2c": 1536,
-         }

      def get_fused_moe_quant_config(
              self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None:

@@ -138,7 +122,6 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):

      def process_weights_after_loading(self, layer: torch.nn.Module):
          assert isinstance(layer, FusedMoE)
-         assert layer.moe_config.has_bias, "mxfp4 quantization alwyas use bias."

          w13_weight = u8_unpack_e2m1(t2j(layer.w13_weight, use_dlpack=False))
          w13_weight_scale = e8m0_to_fp32(

@@ -157,8 +140,6 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
          w2_weight = dequantize_block_weight(w2_weight, w2_weight_scale,
                                              MXFP4_BLOCK_SIZE, jnp.bfloat16)

-         num_experts, hidden_size, intermediate_size = w2_weight.shape
-
          # Because we have dequantized weights, scales are not used anymore.
          delattr(layer, "w13_weight_scale")
          delattr(layer, "w2_weight_scale")

@@ -176,89 +157,63 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
          w3_bias = w13_bias[:, 1::2]
          w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)

-
-
-         # w13: (num_experts, 2, hidden_size, intermediate_size)
-         # w2: (num_experts, intermediate_size, hidden_size)
-         # Current format:
-         # w13_weight: (num_experts, 2*intermediate_size, hidden_size)
-         # w2_weight: (num_experts, hidden_size, intermediate_size)
-
-         w13_reshaped = w13_weight.reshape(num_experts, 2,
-                                           intermediate_size, hidden_size)
-
-         # Transpose non-constracting dim to right most dim
-         w13_weight_transposed = jnp.swapaxes(w13_reshaped, 2, 3)
-         w2_weight_transposed = jnp.swapaxes(w2_weight, 1, 2)
-
-         # Apply EP sharding
-         ep_sharding = NamedSharding(self.mesh, P("model"))
-
+         # TODO(kyuyeunk): Add weight processing logic for the new kernel.
+         if layer.use_ep:
              w13_weight = jax.device_put(
-             … (10 deleted lines not captured)
+                 w13_weight,
+                 Format(Layout((0, 1, 2)),
+                        NamedSharding(self.mesh, P("model", None, None))))
+             w2_weight = jax.device_put(
+                 w2_weight,
+                 Format(Layout((0, 1, 2)),
+                        NamedSharding(self.mesh, P("model", None, None))))
+
+             w13_bias = jax.device_put(
+                 w13_bias,
+                 Format(Layout((0, 1)),
+                        NamedSharding(self.mesh, P("model", None))))
+             w2_bias = jax.device_put(
+                 w2_bias,
+                 Format(Layout((0, 1)),
+                        NamedSharding(self.mesh, P("model", None))))

          else:
-             … (30 deleted lines not captured)
-                            NamedSharding(self.mesh, P(None, None, "model"))))
-
-             w13_bias = reorder_concatenated_tensor_for_sharding(
-                 w13_bias,
-                 output_sizes,
-                 n_shards,
-                 dim=1,
-             )
-             w13_bias = jax.device_put(
-                 w13_bias,
-                 Format(Layout((0, 1)),
-                        NamedSharding(self.mesh, P(None, "model"))))
-             w2_bias = jax.device_put(
-                 w2_bias,
-                 Format(Layout((0, 1)),
-                        NamedSharding(self.mesh, P(None, None))))
+             intermediate_size = w13_weight.shape[1] // 2
+             assert intermediate_size == w2_weight.shape[-1]
+             output_sizes = [intermediate_size, intermediate_size]
+             n_shards = self.mesh.shape["model"]
+             assert intermediate_size % n_shards == 0
+             w13_weight = reorder_concatenated_tensor_for_sharding(w13_weight,
+                                                                   output_sizes,
+                                                                   n_shards,
+                                                                   dim=1)
+             w13_weight = jax.device_put(
+                 w13_weight,
+                 Format(Layout((0, 1, 2)),
+                        NamedSharding(self.mesh, P(None, "model", None))))
+             w2_weight = jax.device_put(
+                 w2_weight,
+                 Format(Layout((0, 1, 2)),
+                        NamedSharding(self.mesh, P(None, None, "model"))))
+
+             w13_bias = reorder_concatenated_tensor_for_sharding(w13_bias,
+                                                                 output_sizes,
+                                                                 n_shards,
+                                                                 dim=1)
+             w13_bias = jax.device_put(
+                 w13_bias,
+                 Format(Layout((0, 1)),
+                        NamedSharding(self.mesh, P(None, "model"))))
+             w2_bias = jax.device_put(
+                 w2_bias,
+                 Format(Layout((0, 1)), NamedSharding(self.mesh, P(None,
+                                                                   None))))

          layer.w13_weight = Parameter(torch_view(w13_weight),
                                       requires_grad=False)
-         layer.w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
-
          layer.w13_bias = Parameter(torch_view(w13_bias), requires_grad=False)
+
+         layer.w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
          layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False)

          pass

@@ -291,41 +246,21 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
              raise NotImplementedError(
                  "Only softmax is supported for scoring_func")

-         … (16 deleted lines not captured)
-                 top_k=top_k,
-                 ep_axis_name=self.ep_axis_name,
-                 renormalize_topk_logits=renormalize,
-                 act_fn=activation,
-                 **self.block_size,
-             )
-         else:
-             output = fused_moe_func(
-                 hidden_states=x,
-                 w1=w13_weight,
-                 w2=w2_weight,
-                 w1_bias=w13_bias,
-                 w2_bias=w2_bias,
-                 gating_output=gating_output,
-                 topk=top_k,
-                 renormalize=renormalize,
-                 mesh=self.mesh,
-                 use_ep=layer.use_ep,
-                 activation=activation,
-             )
+         # Use the original implementation
+         output = fused_moe_func_padded(
+             jax_view(x),
+             jax_view(layer.w13_weight),
+             jax_view(layer.w2_weight),
+             jax_view(layer.w13_bias) if self.moe.has_bias else None,
+             jax_view(layer.w2_bias) if self.moe.has_bias else None,
+             jax_view(router_logits),
+             topk=top_k,
+             global_num_experts=global_num_experts,
+             renormalize=renormalize,
+             reduce_results=layer.reduce_results,
+             mesh=self.mesh,
+             use_ep=layer.use_ep,
+             activation=activation,
+         )

          return torch_view(output)
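The new MXFP4 MoE path above drops the kernel-specific block-size table and, when layer.use_ep is set, simply shards the expert dimension (axis 0) of each dequantized tensor across the mesh's "model" axis; otherwise it reorders the fused w1/w3 blocks and shards the intermediate dimension tensor-parallel. Below is a minimal sketch of the expert-parallel placement only, using stable JAX sharding APIs; the shapes are invented, and the Format(Layout(...)) wrapper the diff uses to pin a device-local layout is omitted.

# Sketch only: expert-parallel placement of MoE weights (assumed shapes).
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

num_experts, intermediate_size, hidden_size = 8, 1024, 512
# Layouts used in the diff after dequantization:
#   w13_weight: (num_experts, 2 * intermediate_size, hidden_size)
#   w2_weight:  (num_experts, hidden_size, intermediate_size)
w13_weight = jnp.zeros((num_experts, 2 * intermediate_size, hidden_size), jnp.bfloat16)
w2_weight = jnp.zeros((num_experts, hidden_size, intermediate_size), jnp.bfloat16)

# Shard axis 0 (experts) across the "model" mesh axis and replicate the rest;
# num_experts must divide evenly across the devices on that axis.
w13_weight = jax.device_put(w13_weight, NamedSharding(mesh, P("model", None, None)))
w2_weight = jax.device_put(w2_weight, NamedSharding(mesh, P("model", None, None)))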
tpu_inference/layers/vllm/quantization/unquantized.py

@@ -25,7 +25,7 @@ from tpu_inference import envs
  from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
  from tpu_inference.layers.common.quant_methods import (UNQUANTIZED,
                                                          get_tpu_quant_method)
- from tpu_inference.layers.vllm.fused_moe import
+ from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded
  from tpu_inference.layers.vllm.linear_common import (
      reorder_concatenated_tensor_for_sharding,
      slice_sharded_tensor_for_concatenation, torch_to_jax_param)

@@ -108,8 +108,6 @@ class VllmUnquantizedLinearMethod(UnquantizedLinearMethod):
                  layer: torch.nn.Module,
                  x: torch.Tensor,
                  bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-         assert isinstance(layer, LinearBase)
-
          with jax.named_scope(layer._get_name()):
              if in_sharding := self.jax_config.get_input_sharding(x):
                  x.shard_(NamedSharding(self.jax_config.mesh, in_sharding))

@@ -168,18 +166,18 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
                   ep_axis_name: str = 'model'):
          super().__init__(moe)
          self.mesh = mesh
-         self.use_kernel = envs.USE_MOE_EP_KERNEL
+         self.use_kernel = envs.USE_MOE_EP_KERNEL
          self.ep_axis_name = ep_axis_name
          # TODO: Use autotune table once we have it.
          self.block_size = {
-             "bt":
-             "bf":
-             "bd1":
-             "bd2":
-             "btc":
-             "bfc":
-             "bd1c":
-             "bd2c":
+             "bt": 16,
+             "bf": 384,
+             "bd1": 512,
+             "bd2": 512,
+             "btc": 16,
+             "bfc": 384,
+             "bd1c": 256,
+             "bd2c": 256,
          }

      def select_gemm_impl(

@@ -196,8 +194,6 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
          w13_weight = t2j(layer.w13_weight, use_dlpack=False)
          w2_weight = t2j(layer.w2_weight, use_dlpack=False)

-         num_experts, hidden_size, intermediate_size = w2_weight.shape
-
          if self.moe.has_bias:
              w13_bias = t2j(layer.w13_bias, use_dlpack=False)
              w2_bias = t2j(layer.w2_bias, use_dlpack=False)

@@ -216,56 +212,76 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
              w3_bias = w13_bias[:, 1::2]
              w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)

-         if self.use_kernel:
+         if self.use_kernel and layer.use_ep:
              # Kernel expects:
              # w13: (num_experts, 2, hidden_size, intermediate_size)
              # w2: (num_experts, intermediate_size, hidden_size)
              # Current format:
              # w13_weight: (num_experts, 2*intermediate_size, hidden_size)
              # w2_weight: (num_experts, hidden_size, intermediate_size)
+             num_experts = w13_weight.shape[0]
+             intermediate_size = w13_weight.shape[1] // 2
+             hidden_size = w13_weight.shape[2]

+             # Reshape and transpose w13_weight to (num_experts, 2, hidden_size, intermediate_size)
              w13_reshaped = w13_weight.reshape(num_experts, 2,
                                                intermediate_size, hidden_size)
+             w13_weight_transposed = jnp.transpose(w13_reshaped, (0, 1, 3, 2))

-             # Transpose
-             … (not captured)
-             w2_weight_transposed = jnp.swapaxes(w2_weight, 1, 2)
+             # Transpose w2_weight to (num_experts, intermediate_size, hidden_size)
+             w2_weight_transposed = jnp.transpose(w2_weight, (0, 2, 1))

              # Apply EP sharding
-             ep_sharding = NamedSharding(self.mesh, P("model"))
-
              w13_weight = jax.device_put(
-                 w13_weight_transposed,
-                 … (3 deleted lines not captured)
+                 w13_weight_transposed,
+                 Format(Layout((0, 1, 2, 3)),
+                        NamedSharding(self.mesh, P("model", None, None, None))))
+             w2_weight = jax.device_put(
+                 w2_weight_transposed,
+                 Format(Layout((0, 1, 2)),
+                        NamedSharding(self.mesh, P("model", None, None))))

              if self.moe.has_bias:
                  w13_bias = w13_bias.reshape(num_experts, 2, intermediate_size)
+
+                 # Apply EP sharding
                  w13_bias = jax.device_put(
-                     w13_bias,
-                     … (3 deleted lines not captured)
+                     w13_bias,
+                     Format(Layout((0, 1, 2)),
+                            NamedSharding(self.mesh, P("model", None, None))))
+                 w2_bias = jax.device_put(
+                     w2_bias,
+                     Format(Layout((0, 1)),
+                            NamedSharding(self.mesh, P("model", None))))

+         else:
+             # Original logic for non-kernel path
              if layer.use_ep:
-                 ep_sharding = NamedSharding(self.mesh, P("model"))
                  w13_weight = jax.device_put(
-                     w13_weight,
+                     w13_weight,
+                     Format(Layout((0, 1, 2)),
+                            NamedSharding(self.mesh, P("model", None, None))))
                  w2_weight = jax.device_put(
-                     w2_weight,
+                     w2_weight,
+                     Format(Layout((0, 1, 2)),
+                            NamedSharding(self.mesh, P("model", None, None))))

                  if self.moe.has_bias:
                      w13_bias = jax.device_put(
-                         w13_bias,
+                         w13_bias,
+                         Format(Layout((0, 1)),
+                                NamedSharding(self.mesh, P("model", None))))
                      w2_bias = jax.device_put(
-                         w2_bias,
+                         w2_bias,
+                         Format(Layout((0, 1)),
+                                NamedSharding(self.mesh, P("model", None))))

              else:
+                 intermediate_size = w13_weight.shape[1] // 2
+                 assert intermediate_size == w2_weight.shape[-1]
                  output_sizes = [intermediate_size, intermediate_size]
                  n_shards = self.mesh.shape["model"]
                  assert intermediate_size % n_shards == 0
-
                  w13_weight = reorder_concatenated_tensor_for_sharding(
                      w13_weight, output_sizes, n_shards, dim=1)
                  w13_weight = jax.device_put(

@@ -326,40 +342,30 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
              raise NotImplementedError(
                  "Only softmax is supported for scoring_func")

-         x = jax_view(x)
-         w13_weight = jax_view(layer.w13_weight)
-         w2_weight = jax_view(layer.w2_weight)
-         w13_bias = w2_bias = None
-         if self.moe.has_bias:
-             w13_bias = jax_view(layer.w13_bias)
-             w2_bias = jax_view(layer.w2_bias)
-         gating_output = jax_view(router_logits)
-
          if self.use_kernel and layer.use_ep:
              output = fused_ep_moe(
                  mesh=self.mesh,
-                 tokens=x,
-                 w1=w13_weight,
-                 w2=w2_weight,
-                 … (not captured)
-                 b2=w2_bias,
-                 gating_output=gating_output,
+                 tokens=jax_view(x),
+                 w1=jax_view(layer.w13_weight),
+                 w2=jax_view(layer.w2_weight),
+                 gating_output=jax_view(router_logits),
                  top_k=top_k,
                  ep_axis_name=self.ep_axis_name,
-                 renormalize_topk_logits=renormalize,
-                 act_fn=activation,
                  **self.block_size,
              )
          else:
-             … (7 deleted lines not captured)
+             # Use the original implementation
+             output = fused_moe_func_padded(
+                 jax_view(x),
+                 jax_view(layer.w13_weight),
+                 jax_view(layer.w2_weight),
+                 jax_view(layer.w13_bias) if self.moe.has_bias else None,
+                 jax_view(layer.w2_bias) if self.moe.has_bias else None,
+                 jax_view(router_logits),
                  topk=top_k,
+                 global_num_experts=global_num_experts,
                  renormalize=renormalize,
+                 reduce_results=layer.reduce_results,
                  mesh=self.mesh,
                  use_ep=layer.use_ep,
                  activation=activation,
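The kernel path above rearranges the fused weights from the checkpoint layout, w13 as (num_experts, 2*intermediate_size, hidden_size) and w2 as (num_experts, hidden_size, intermediate_size), into the layout the EP kernel expects, (num_experts, 2, hidden_size, intermediate_size) and (num_experts, intermediate_size, hidden_size). A small self-contained sketch of that rearrangement, with shapes invented for illustration:

# Sketch only: reshape/transpose into the kernel's expected MoE weight layout.
import jax.numpy as jnp

num_experts, intermediate_size, hidden_size = 4, 6, 3

# Checkpoint layouts, as described in the diff comments above.
w13_weight = jnp.zeros((num_experts, 2 * intermediate_size, hidden_size), jnp.bfloat16)
w2_weight = jnp.zeros((num_experts, hidden_size, intermediate_size), jnp.bfloat16)

# Split the fused gate/up projection into two halves, then swap the last two
# axes so hidden_size comes before intermediate_size.
w13_reshaped = w13_weight.reshape(num_experts, 2, intermediate_size, hidden_size)
w13_for_kernel = jnp.transpose(w13_reshaped, (0, 1, 3, 2))
w2_for_kernel = jnp.transpose(w2_weight, (0, 2, 1))

assert w13_for_kernel.shape == (num_experts, 2, hidden_size, intermediate_size)
assert w2_for_kernel.shape == (num_experts, intermediate_size, hidden_size)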
tpu_inference/layers/vllm/sharding.py

@@ -19,7 +19,6 @@ from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA
  from vllm.model_executor.layers.vocab_parallel_embedding import (
      ParallelLMHead, VocabParallelEmbedding)

- from tpu_inference import envs
  from tpu_inference.logger import init_logger

  P = PartitionSpec

@@ -212,7 +211,8 @@ def _shard_module_to_tpu(model: torch.nn.Module, mesh: Mesh) -> None:
  def _sharded_device_put(tensor: jax.Array, sharding) -> jax.Array:
      if isinstance(tensor, tuple):
          return tuple(_sharded_device_put(t, sharding) for t in tensor)
-     … (not captured)
+     import os
+     multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
      if multihost_backend != "ray":
          return jax.device_put(tensor, sharding)
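The _sharded_device_put change reads TPU_MULTIHOST_BACKEND straight from the process environment and only skips the plain jax.device_put path when the backend is "ray". A minimal sketch of that gate, with a hypothetical helper name:

# Sketch only: environment-variable gate for the Ray multihost backend.
import os

def _use_ray_backend() -> bool:
    # An unset variable yields "", so single-host runs fall through to the
    # plain jax.device_put branch.
    return os.environ.get("TPU_MULTIHOST_BACKEND", "").lower() == "ray"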
tpu_inference/lora/torch_punica_tpu.py

@@ -239,6 +239,7 @@ class PunicaWrapperTPU(PunicaWrapperBase):
          lora_index_to_id: list[Optional[int]],
          max_loras: int,
          vocab_size: int,
+         extra_vocab_size: int,
      ):
          # Pad the prompt mapping to avoid running into recompiles on the TPU
          # TODO: Should this happen inside mapping internally? If so how can we

@@ -257,7 +258,7 @@ class PunicaWrapperTPU(PunicaWrapperBase):
              lora_index_to_id,
              max_loras,
              vocab_size,
-             … (not captured)
+             extra_vocab_size,
              "cpu",
          )
          with torchax.default_env():
tpu_inference/mock/__init__.py — file without changes (empty __init__.py).
tpu_inference/mock/vllm_config_utils.py (new file)

@@ -0,0 +1,28 @@
+ from dataclasses import dataclass, field
+ from typing import Any, List, Mapping
+
+
+ @dataclass
+ class ModelConfig():
+     max_model_len: int = 2048
+     max_prefill_len: int = 1024
+     prefill_batch_size: int = 1
+     decode_batch_size: int = 1
+     block_size: int = 16
+     num_layers: int = 32
+     num_kv_heads: int = 32
+     head_dim: int = 128
+     vocab_size: int = 32000
+     model: str = "llama3"
+     hf_config: str = ""
+     architectures: List[str] = field(default_factory=list)
+     override_generation_config: dict[str, Any] = field(default_factory=dict)
+     hf_overrides: dict[str, Any] = field(default_factory=dict)
+
+
+ @dataclass
+ class VllmConfig():
+     additional_config: Mapping[str, Any] = field(default_factory=dict)
+     # Set default max_model_len to turn off warnings.
+     model_config: ModelConfig = field(
+         default_factory=lambda: ModelConfig(max_model_len=1024))
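Usage sketch for the new mock config module, assuming it is importable as tpu_inference.mock.vllm_config_utils; the override values and the "example_flag" key below are invented for illustration.

# Sketch only: constructing the mock config dataclasses added above.
from tpu_inference.mock.vllm_config_utils import ModelConfig, VllmConfig

# Default wrapper: ModelConfig(max_model_len=1024) to silence length warnings.
config = VllmConfig()
assert config.model_config.max_model_len == 1024

# A smaller hypothetical model shape for tests.
small = ModelConfig(num_layers=2, num_kv_heads=4, head_dim=64, max_model_len=512)
config = VllmConfig(model_config=small, additional_config={"example_flag": True})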
|