tpu-inference 0.12.0.dev20251222__py3-none-any.whl → 0.12.0.dev20251224__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/core/test_dp_scheduler.py +128 -71
- tests/e2e/test_data_parallel.py +176 -280
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_speculative_decoding.py +26 -6
- tests/layers/jax/test_qwix.py +1 -1
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +36 -21
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +36 -21
- tests/layers/vllm/test_mxfp4.py +25 -10
- tests/layers/vllm/test_unquantized.py +61 -31
- tests/layers/vllm/utils.py +19 -4
- tests/models/common/test_model_loader.py +2 -2
- tests/models/jax/test_qwen2_5_vl.py +10 -11
- tests/runner/test_multimodal_manager.py +3 -3
- tests/runner/test_tpu_runner.py +67 -8
- tests/runner/test_tpu_runner_dp.py +66 -0
- tpu_inference/core/sched/dp_scheduler.py +65 -40
- tpu_inference/kernels/mla/v1/kernel.py +7 -26
- tpu_inference/layers/common/sharding.py +8 -3
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +3 -3
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +3 -3
- tpu_inference/layers/jax/attention/llama4_attention.py +3 -4
- tpu_inference/layers/jax/sample/sampling.py +1 -1
- tpu_inference/layers/vllm/fused_moe.py +51 -47
- tpu_inference/layers/vllm/quantization/common.py +14 -13
- tpu_inference/layers/vllm/quantization/mxfp4.py +21 -7
- tpu_inference/layers/vllm/quantization/unquantized.py +19 -7
- tpu_inference/layers/vllm/sharding.py +7 -4
- tpu_inference/models/common/model_loader.py +11 -14
- tpu_inference/models/jax/llama3.py +13 -10
- tpu_inference/models/jax/llama_guard_4.py +1 -1
- tpu_inference/models/jax/qwen2.py +3 -2
- tpu_inference/models/jax/qwen2_5_vl.py +4 -4
- tpu_inference/models/jax/utils/multi_modal_utils.py +4 -4
- tpu_inference/models/jax/utils/qwix/qwix_utils.py +3 -3
- tpu_inference/models/vllm/vllm_model_wrapper.py +5 -2
- tpu_inference/platforms/tpu_platform.py +7 -7
- tpu_inference/runner/compilation_manager.py +43 -33
- tpu_inference/runner/kv_cache_manager.py +1 -2
- tpu_inference/runner/multimodal_manager.py +1 -1
- tpu_inference/runner/tpu_runner.py +12 -9
- tpu_inference/utils.py +31 -30
- tpu_inference/worker/tpu_worker.py +5 -2
- {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/METADATA +1 -1
- {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/RECORD +47 -46
- {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/WHEEL +0 -0
- {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/top_level.txt +0 -0

@@ -25,9 +25,10 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
 
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.vllm.linear_common import \
     get_model_matmul_fusion_assignment
-from tpu_inference.utils import TPU_SECOND_LAST_MINOR
+from tpu_inference.utils import TPU_SECOND_LAST_MINOR, get_mesh_shape_product
 
 # yapf: enable
 
@@ -49,14 +50,18 @@ class JaxCommonLinearConfig:
         self.input_sharding = None
         self.output_sharding = None
 
+        self.tp_size = get_mesh_shape_product(self.mesh,
+                                              ShardingAxisName.MLP_TENSOR)
+
         if isinstance(layer, RowParallelLinear):
-            self.weight_sharding = P(None,
+            self.weight_sharding = P(None, ShardingAxisName.ATTN_HEAD)
             if self.enable_sp:
-                self.output_sharding = P(
+                self.output_sharding = P(ShardingAxisName.MLP_TENSOR, None)
         elif isinstance(layer, ColumnParallelLinear):
-            self.weight_sharding = P(
+            self.weight_sharding = P(ShardingAxisName.ATTN_HEAD, None)
+
             if self.enable_sp:
-                self.input_sharding = P(
+                self.input_sharding = P(ShardingAxisName.MLP_TENSOR, None)
 
         if isinstance(layer, MergedColumnParallelLinear) or isinstance(
                 layer, QKVParallelLinear):
@@ -75,18 +80,14 @@ class JaxCommonLinearConfig:
                 " bad performance.", type(layer))
 
         self.bias_sharding = P(self.weight_sharding[0])
-
-
-            for axis in self.weight_sharding[0]:
-                self.n_shards *= self.mesh.shape.get(axis, 1)
-        else:
-            self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)
+        self.n_shards = get_mesh_shape_product(self.mesh,
+                                               self.weight_sharding[0])
 
     def get_input_sharding(self, x: torchax.tensor.Tensor):
         if self.enable_sp:
             token_num = x.shape[0]
             # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
-            if token_num // self.
+            if token_num // self.tp_size >= TPU_SECOND_LAST_MINOR:
                 return self.input_sharding
             else:
                 return None
@@ -96,7 +97,7 @@ class JaxCommonLinearConfig:
         if self.enable_sp:
             token_num = x.shape[0]
             # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
-            if token_num // self.
+            if token_num // self.tp_size >= TPU_SECOND_LAST_MINOR:
                 return self.output_sharding
             else:
                 return None
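
Several hunks in this diff replace open-coded products over mesh axis sizes with a shared `get_mesh_shape_product` helper imported from `tpu_inference.utils`. A minimal sketch of what such a helper plausibly does is shown below; the sketch's name and its handling of `None` are assumptions, and only the removed open-coded logic above is taken from the diff.

```python
# Hedged sketch only: the real get_mesh_shape_product lives in
# tpu_inference.utils and may differ. Assumption: it returns the product of the
# mesh axis sizes for one axis name or a tuple of axis names, treating missing
# axes as size 1 (mirroring the removed `mesh.shape.get(axis, 1)` logic above).
from jax.sharding import Mesh


def get_mesh_shape_product_sketch(mesh: Mesh, axis_names) -> int:
    if axis_names is None:
        return 1
    if isinstance(axis_names, str):
        axis_names = (axis_names, )
    size = 1
    for name in axis_names:
        size *= mesh.shape.get(name, 1)
    return size
```

For a mesh created as `Mesh(devices, ("data", "model"))`, passing `("model",)` would return the tensor-parallel degree, which matches how `tp_size` and `n_shards` are derived in the hunks above.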
@@ -44,12 +44,14 @@ from tpu_inference.layers.common.quant_methods import (MXFP4,
                                                         get_tpu_quant_method)
 from tpu_inference.layers.common.quantization import (
     dequantize_tensor_from_mxfp4_packed, quantize_tensor)
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.vllm.fused_moe import fused_moe_func
 from tpu_inference.layers.vllm.linear_common import \
     reorder_concatenated_tensor_for_sharding
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
 from tpu_inference.layers.vllm.quantization.unquantized import \
     VllmUnquantizedLinearMethod
+from tpu_inference.utils import get_mesh_shape_product
 
 REQUANTIZED_BLOCK_SIZE = 512
 
@@ -256,7 +258,8 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
             w2_bias = jnp.expand_dims(w2_bias, 1)
 
         if layer.use_ep:
-            ep_sharding = NamedSharding(self.mesh,
+            ep_sharding = NamedSharding(self.mesh,
+                                        P(ShardingAxisName.EXPERT))
 
             w13_weight = jax.lax.with_sharding_constraint(
                 w13_weight, ep_sharding)
@@ -275,7 +278,8 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
 
         else:
             output_sizes = [intermediate_size, intermediate_size]
-            n_shards =
+            n_shards = get_mesh_shape_product(
+                self.mesh, ShardingAxisName.MLP_TENSOR)
             assert intermediate_size % n_shards == 0
 
             # Reorder w13 weights so that splitting w1 and w3 output
@@ -301,19 +305,29 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
 
             w13_weight = jax.lax.with_sharding_constraint(
                 w13_weight,
-                NamedSharding(
+                NamedSharding(
+                    self.mesh,
+                    P(None, ShardingAxisName.MLP_TENSOR, None)))
             w2_weight = jax.lax.with_sharding_constraint(
                 w2_weight,
-                NamedSharding(
+                NamedSharding(
+                    self.mesh,
+                    P(None, None, ShardingAxisName.MLP_TENSOR)))
             w13_weight_scale = jax.lax.with_sharding_constraint(
                 w13_weight_scale,
-                NamedSharding(
+                NamedSharding(
+                    self.mesh,
+                    P(None, None, None, ShardingAxisName.MLP_TENSOR)))
             w2_weight_scale = jax.lax.with_sharding_constraint(
                 w2_weight_scale,
-                NamedSharding(
+                NamedSharding(
+                    self.mesh,
+                    P(None, ShardingAxisName.MLP_TENSOR, None, None)))
             w13_bias = jax.lax.with_sharding_constraint(
                 w13_bias,
-                NamedSharding(
+                NamedSharding(
+                    self.mesh,
+                    P(None, None, ShardingAxisName.MLP_TENSOR)))
             w2_bias = jax.lax.with_sharding_constraint(
                 w2_bias, NamedSharding(self.mesh, P(None, None, None)))
 
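
The MoE hunks above swap implicit axis references for `ShardingAxisName.EXPERT` / `ShardingAxisName.MLP_TENSOR` constants when constraining the stacked expert weights. The standalone sketch below illustrates the same `with_sharding_constraint` + `NamedSharding` pattern on toy shapes; the axis name "model", the shapes, and the `annotate` helper are assumptions for illustration, not tpu_inference code.

```python
# Self-contained illustration of the annotation pattern in these hunks:
# stacked MoE weights are constrained either along the expert dimension
# (expert parallelism) or along an intermediate dimension (tensor parallelism).
import functools
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("model", ))  # 1-D mesh for illustration


@functools.partial(jax.jit, static_argnames="use_ep")
def annotate(w13, w2, use_ep=False):
    if use_ep:
        # Expert parallelism: shard the leading [num_experts] dimension.
        ep = NamedSharding(mesh, P("model"))
        return (jax.lax.with_sharding_constraint(w13, ep),
                jax.lax.with_sharding_constraint(w2, ep))
    # Tensor parallelism: shard the intermediate dimensions instead,
    # mirroring P(None, MLP_TENSOR, None) / P(None, None, MLP_TENSOR) above.
    w13 = jax.lax.with_sharding_constraint(
        w13, NamedSharding(mesh, P(None, "model", None)))
    w2 = jax.lax.with_sharding_constraint(
        w2, NamedSharding(mesh, P(None, None, "model")))
    return w13, w2


w13 = jnp.zeros((8, 512, 1024))  # assumed [experts, 2 * intermediate, hidden]
w2 = jnp.zeros((8, 1024, 256))   # assumed [experts, hidden, intermediate]
w13, w2 = annotate(w13, w2)
```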
@@ -39,12 +39,14 @@ from tpu_inference import envs
 from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
 from tpu_inference.layers.common.quant_methods import (UNQUANTIZED,
                                                        get_tpu_quant_method)
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.vllm.fused_moe import fused_moe_func
 from tpu_inference.layers.vllm.linear_common import (
     reorder_concatenated_tensor_for_sharding,
     slice_sharded_tensor_for_concatenation, torch_to_jax_param)
 from tpu_inference.layers.vllm.quantization.common import (
     JaxCommonConfig, JaxCommonLinearConfig)
+from tpu_inference.utils import get_mesh_shape_product
 
 P = PartitionSpec
 logger = init_logger(__name__)
@@ -307,7 +309,8 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
             w2_bias = jnp.expand_dims(w2_bias, 1)
 
         if layer.use_ep:
-            ep_sharding = NamedSharding(self.mesh,
+            ep_sharding = NamedSharding(self.mesh,
+                                        P(ShardingAxisName.EXPERT))
             w13_weight = jax.device_put(
                 w13_weight, Format(Layout((0, 1, 2)), ep_sharding))
             w2_weight = jax.device_put(
@@ -321,19 +324,26 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
 
         else:
             output_sizes = [intermediate_size, intermediate_size]
-            n_shards = self.mesh
+            n_shards = get_mesh_shape_product(self.mesh,
+                                              ShardingAxisName.MLP_TENSOR)
             assert intermediate_size % n_shards == 0
 
             w13_weight = reorder_concatenated_tensor_for_sharding(
                 w13_weight, output_sizes, n_shards, dim=1)
             w13_weight = jax.device_put(
                 w13_weight,
-                Format(
-
+                Format(
+                    Layout((0, 1, 2)),
+                    NamedSharding(
+                        self.mesh,
+                        P(None, ShardingAxisName.MLP_TENSOR, None))))
             w2_weight = jax.device_put(
                 w2_weight,
-                Format(
-
+                Format(
+                    Layout((0, 1, 2)),
+                    NamedSharding(
+                        self.mesh,
+                        P(None, None, ShardingAxisName.MLP_TENSOR))))
 
         if self.moe.has_bias:
             w13_bias = reorder_concatenated_tensor_for_sharding(
@@ -343,7 +353,9 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
                 w13_bias,
                 Format(
                     Layout((0, 1, 2)),
-                    NamedSharding(
+                    NamedSharding(
+                        self.mesh,
+                        P(None, None, ShardingAxisName.MLP_TENSOR))))
             w2_bias = jax.device_put(
                 w2_bias,
                 Format(Layout((0, 1, 2)),
@@ -34,6 +34,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 
 from tpu_inference import envs
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
 
 P = PartitionSpec
@@ -123,7 +124,8 @@ def _shard_tensor_to_tpu_replicated(tensor: torch.Tensor,
 def _shard_vocab_parallel_embedding(layer: VocabParallelEmbedding,
                                     mesh: Mesh) -> None:
     weight = _convert_to_torchax_and_shard(
-        layer.weight, NamedSharding(mesh, P(
+        layer.weight, NamedSharding(mesh, P(ShardingAxisName.MLP_TENSOR,
+                                            None)))
     layer.weight = Parameter(weight, requires_grad=False)
 
 
@@ -132,11 +134,12 @@ def _shard_lm_head(layer: ParallelLMHead, mesh: Mesh):
     # if that config is set, then we should not create new weights but reuse the
     # weight from VocabParallelEmbedding
     weight = _convert_to_torchax_and_shard(
-        layer.weight, NamedSharding(mesh, P(
+        layer.weight, NamedSharding(mesh, P(ShardingAxisName.MLP_TENSOR,
+                                            None)))
     layer.weight = Parameter(weight, requires_grad=False)
     if layer.bias is not None:
-        bias = _convert_to_torchax_and_shard(
-
+        bias = _convert_to_torchax_and_shard(
+            layer.bias, NamedSharding(mesh, P(ShardingAxisName.MLP_TENSOR)))
         layer.bias = Parameter(bias, requires_grad=False)
 
 
@@ -283,10 +283,9 @@ def get_flax_model(
 
     # Multi-modal support only
     # This function calculates the image token's embeddings by VIT
-    def
-                             **kwargs):
+    def run_embed_multimodal(graphdef, state, image_grid_thw, **kwargs):
         model = nnx.merge(graphdef, state)
-        return model.
+        return model.embed_multimodal(image_grid_thw, **kwargs)
 
     embed_sharding = NamedSharding(mesh, PartitionSpec(None))
     # This function will calculates the embeddings of input texts and then merge with the image embeddings
@@ -294,9 +293,9 @@ def get_flax_model(
         jax.jit,
         out_shardings=(embed_sharding),
     )
-    def
+    def run_embed_input_ids(graphdef, state, *args, **kwargs):
         model = nnx.merge(graphdef, state)
-        return model.
+        return model.embed_input_ids(*args, **kwargs)
 
     # For models that want to work with EAGLE-3 speculative decoding
     @functools.partial(
@@ -312,10 +311,8 @@ def get_flax_model(
                                  None)
     model_fn = functools.partial(run_model, graphdef)
     compute_logits_fn = functools.partial(run_compute_logits, graphdef)
-
-
-    get_input_embeddings_fn = functools.partial(run_get_input_embeddings,
-                                                graphdef)
+    embed_multimodal_fn = functools.partial(run_embed_multimodal, graphdef)
+    embed_input_ids_fn = functools.partial(run_embed_input_ids, graphdef)
     lora_manager, model = None, None
     combine_hidden_states_fn = functools.partial(combine_hidden_states,
                                                  graphdef)
@@ -326,8 +323,8 @@ def get_flax_model(
 
     multimodal_fns = {
        "precompile_vision_encoder_fn": precompile_vision_encoder_fn,
-        "
-        "
+        "embed_multimodal_fn": embed_multimodal_fn,
+        "embed_input_ids_fn": embed_input_ids_fn,
         "get_mrope_input_positions_fn": get_mrope_input_positions_fn,
     }
 
@@ -485,14 +482,14 @@ def register_model(arch: str, model: Any) -> None:
     )
 
     # Same as `forward`, this is a dummy method to satisfy vLLM's type checks.
-    def
+    def unimplemented_embed_input_ids(
         self,
         input_ids: "torch.Tensor",
         positions: "torch.Tensor",
         inputs_embeds: Optional["torch.Tensor"] = None,
     ) -> "torch.Tensor":
         raise NotImplementedError(
-            "This is a JAX model and does not implement the PyTorch
+            "This is a JAX model and does not implement the PyTorch embed_input_ids method."
         )
 
     # We need a custom __init__ that only calls torch.nn.Module's init,
@@ -508,7 +505,7 @@ def register_model(arch: str, model: Any) -> None:
         {
             "__init__": wrapper_init,
             "forward": unimplemented_forward,
-            "
+            "embed_input_ids": unimplemented_embed_input_ids,
             # Prevent vLLM from trying to load weights into this dummy class.
             "load_weights": lambda self, *args, **kwargs: None,
         })
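
These hunks rename the placeholder method on the PyTorch-facing dummy wrapper to `embed_input_ids`. For context, the pattern being edited builds a dummy `torch.nn.Module` subclass with `type()` so vLLM's interface checks pass while inference actually runs in JAX. The sketch below is a rough standalone illustration with assumed class and helper names, not the package's `register_model` code.

```python
# Hedged sketch of the dummy-wrapper pattern; names are assumptions.
import torch


def wrapper_init(self, *args, **kwargs):
    # Only initialize torch.nn.Module; no real parameters are created here.
    torch.nn.Module.__init__(self)


def unimplemented_forward(self, *args, **kwargs):
    raise NotImplementedError("JAX model: PyTorch forward is not implemented.")


def unimplemented_embed_input_ids(self, input_ids, positions, inputs_embeds=None):
    raise NotImplementedError(
        "JAX model: PyTorch embed_input_ids is not implemented.")


JaxDummyModel = type(
    "JaxDummyModel",
    (torch.nn.Module, ),
    {
        "__init__": wrapper_init,
        "forward": unimplemented_forward,
        "embed_input_ids": unimplemented_embed_input_ids,
        # Prevent weight loading into this dummy class.
        "load_weights": lambda self, *args, **kwargs: None,
    })

model = JaxDummyModel()  # satisfies isinstance(model, torch.nn.Module)
```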
@@ -26,6 +26,7 @@ from tpu_inference import utils
 from tpu_inference.distributed.jax_parallel_state import get_pp_group
 from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.quantization import quantize_kv
 from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.jax.pp_utils import PPMissingLayer, make_layers
 from tpu_inference.layers.jax.rope_interface import apply_rope
@@ -34,6 +35,7 @@ from tpu_inference.models.jax.jax_intermediate_tensor import \
     JaxIntermediateTensors
 from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
                                                          load_hf_weights)
+from tpu_inference.utils import get_mesh_shape_product
 
 logger = init_logger(__name__)
 
@@ -98,7 +100,8 @@ class LlamaAttention(nnx.Module):
                                              self.hidden_size // self.num_heads)
         self.head_dim = utils.get_padded_head_dim(self.head_dim_original)
 
-        sharding_size = mesh
+        sharding_size = get_mesh_shape_product(mesh,
+                                               ShardingAxisName.MLP_TENSOR)
         self.num_heads = utils.get_padded_num_heads(self.num_heads,
                                                     sharding_size)
         self.num_kv_heads = utils.get_padded_num_heads(self.num_kv_heads,
@@ -171,8 +174,8 @@ class LlamaAttention(nnx.Module):
         # q_scale = self._q_scale
         k_scale = self._k_scale
         v_scale = self._v_scale
-        k, v =
-
+        k, v = quantize_kv(self.kv_cache_quantized_dtype, k, v, k_scale,
+                           v_scale)
         new_kv_cache, outputs = attention(
             kv_cache,
             q,
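
This hunk (and the matching Qwen2 hunk later in the diff) routes K/V through the newly imported `quantize_kv` before the KV cache is updated. The diff only shows the call site, so the sketch below is a guess at the general shape of such a helper (scale, then cast); the real `tpu_inference.layers.common.quantization.quantize_kv` may behave differently.

```python
# Hedged sketch only, assuming quantize_kv rescales and casts K/V projections
# before they are written to the KV cache; internals of the real helper may differ.
import jax.numpy as jnp


def quantize_kv_sketch(quantized_dtype, k, v, k_scale, v_scale):
    if quantized_dtype is None:
        # KV-cache quantization disabled: pass K/V through unchanged.
        return k, v
    k_q = (k / k_scale).astype(quantized_dtype)
    v_q = (v / v_scale).astype(quantized_dtype)
    return k_q, v_q


# Example: quantize to int8 with per-tensor scales.
k = jnp.ones((4, 8, 128), jnp.bfloat16)
v = jnp.ones((4, 8, 128), jnp.bfloat16)
k_q, v_q = quantize_kv_sketch(jnp.int8, k, v, k_scale=0.05, v_scale=0.05)
```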
@@ -369,13 +372,13 @@ class LlamaForCausalLM(nnx.Module):
         kv_caches: List[jax.Array],
         input_ids: jax.Array,
         attention_metadata: AttentionMetadata,
-        _input_embeds,
-        _input_positions,
-        _layer_name_to_kv_cache,
-        _lora_metadata,
-        intermediate_tensors: JaxIntermediateTensors,
-        _is_first_rank: bool,
-        _is_last_rank: bool,
+        _input_embeds=None,
+        _input_positions=None,
+        _layer_name_to_kv_cache=None,
+        _lora_metadata=None,
+        intermediate_tensors: JaxIntermediateTensors | None = None,
+        _is_first_rank: bool | None = None,
+        _is_last_rank: bool | None = None,
         *args,
     ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]] | Tuple[
             List[jax.Array], JaxIntermediateTensors]:
@@ -256,7 +256,7 @@ class LlamaGuard4ForCausalLM(nnx.Module):
             self.lm_head.input_embedding_table_DV.value)
         return logits_TV
 
-    def
+    def embed_input_ids(
         self,
         input_ids: jax.Array,
         multimodal_embeddings: Optional[List[jax.Array]] = None
@@ -24,6 +24,7 @@ from vllm.config import VllmConfig
 from tpu_inference import utils
 from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.quantization import quantize_kv
 from tpu_inference.layers.jax.rope_interface import apply_rope
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
@@ -166,8 +167,8 @@ class Qwen2Attention(nnx.Module):
         # q_scale = self._q_scale
         k_scale = self._k_scale
         v_scale = self._v_scale
-        k, v =
-
+        k, v = quantize_kv(self.kv_cache_quantized_dtype, k, v, k_scale,
+                           v_scale)
         new_kv_cache, outputs = attention(
             kv_cache,
             q,
@@ -1010,9 +1010,9 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
         split_indices = np.cumsum(sizes)[:-1]
         return tuple(jnp.split(image_embeds, split_indices))
 
-    def
-
-
+    def embed_multimodal(self, image_grid_thw: tuple[tuple[int, int, int],
+                                                     ...],
+                         **kwargs: object) -> MultiModalEmbeddings:
 
         mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
             image_grid_thw, **kwargs)
@@ -1036,7 +1036,7 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
 
         return multimodal_embeddings
 
-    def
+    def embed_input_ids(
             self, input_ids: jax.Array,
             multimodal_embeddings: Optional[jax.Array]) -> jax.Array:
 
@@ -43,25 +43,25 @@ def sanity_check_mm_encoder_outputs(
 ) -> None:
     """
     Perform sanity checks for the result of
-    [`vllm.model_executor.models.SupportsMultiModal.
+    [`vllm.model_executor.models.SupportsMultiModal.embed_multimodal`][].
     """
     assert isinstance(mm_embeddings, (list, tuple, jax.Array)), (
         "Expected multimodal embeddings to be a list/tuple of 2D tensors, "
         f"or a single 3D tensor, but got {type(mm_embeddings)} "
         "instead. This is most likely due to incorrect implementation "
-        "of the model's `
+        "of the model's `embed_multimodal` method.")
 
     assert len(mm_embeddings) == expected_num_items, (
         "Expected number of multimodal embeddings to match number of "
         f"input items: {expected_num_items}, but got {len(mm_embeddings)=} "
         "instead. This is most likely due to incorrect implementation "
-        "of the model's `
+        "of the model's `embed_multimodal` method.")
 
     assert all(e.ndim == 2 for e in mm_embeddings), (
         "Expected multimodal embeddings to be a sequence of 2D tensors, "
         f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
         "instead. This is most likely due to incorrect implementation "
-        "of the model's `
+        "of the model's `embed_multimodal` method.")
 
 
 def flatten_embeddings(embeddings: NestedTensors) -> jax.Array:
@@ -35,7 +35,7 @@ DEFAULT_NUM_TOKENS_FOR_MODEL_INPUTS = 512
 DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS = 256
 DEFAULT_MAX_NUM_BLOCKS_PER_REQ = 16
 
-
+DEFAULT_DEEPSEEK_FP4_MLP_MOE_FP8_ATTN_CONFIG = {
     "qwix": {
         "use_abstract_model":
         True,
@@ -452,7 +452,7 @@ def get_default_qwix_quantization_config(
     # NOTE (jacobplatin): we'll default to mixed FP8 (attention) + FP4 (MoE experts)
     # for DeepSeek
     if model_type == "deepseek_v3" and quant_method == "fp8":
-        config = copy.deepcopy(
+        config = copy.deepcopy(DEFAULT_DEEPSEEK_FP4_MLP_MOE_FP8_ATTN_CONFIG)
 
         # Dynamically fetch block size from HF config if available
         # Config fmt: 'weight_block_size': [1, 512] -> we want the 2nd dim for tile_size
@@ -462,7 +462,7 @@ def get_default_qwix_quantization_config(
             block_size = hf_quant_config["weight_block_size"]
             if isinstance(block_size, (list, tuple)) and len(block_size) == 2:
                 assert block_size[
-                    0] == 1, f"Expected first dimension to be 1 (unchanneled), but got {block_size[0]}!"
+                    0] == 1, f"Expected first dimension to be 1 (unchanneled), but got {block_size[0]}! If you are trying to run quantized DeepSeek, we currently only support 1D-subchannel quantization and those models can be found here: https://huggingface.co/collections/jrplatin/deepseek-r1-1d-subchannel"
                 tile_size = block_size[1]
                 assert tile_size > 1, f"Expected tile_size > 1 for DeepSeek, but got {tile_size}"
                 logger.info(
@@ -37,6 +37,7 @@ from vllm.model_executor.models import supports_lora, supports_multimodal
 from vllm.sequence import IntermediateTensors
 
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
 from tpu_inference.layers.vllm.sharding import shard_model_to_tpu
 from tpu_inference.logger import init_logger
@@ -234,8 +235,10 @@ class VllmModelWrapper:
 
         @functools.partial(
             jax.jit,
-            out_shardings=(NamedSharding(
-
+            out_shardings=(NamedSharding(
+                self.mesh,
+                PartitionSpec(ShardingAxisName.MLP_DATA,
+                              ShardingAxisName.MLP_TENSOR))),
         )
         def compute_logits_func(
             params_and_buffers: Any,
@@ -168,12 +168,12 @@ class TpuPlatform(Platform):
         multihost_backend = envs.TPU_MULTIHOST_BACKEND
         if not multihost_backend:  # Single host
             if parallel_config.pipeline_parallel_size == 1:
-                logger.info("Force using UniProcExecutor for JAX on
-
+                logger.info("Force using UniProcExecutor for JAX on "
+                            "single host without pipeline parallelism.")
                 parallel_config.distributed_executor_backend = "uni"
             else:
-                logger.info("Force using MultiprocExecutor for JAX on
-
+                logger.info("Force using MultiprocExecutor for JAX on "
+                            "single host with pipeline parallelism.")
                 parallel_config.distributed_executor_backend = "mp"
         elif multihost_backend == "ray":
             from tpu_inference.executors.ray_distributed_executor import \
@@ -189,9 +189,9 @@ class TpuPlatform(Platform):
 
         if scheduler_config.is_multimodal_model and not \
             scheduler_config.disable_chunked_mm_input:
-            logger.warning("TPU does not support running Multimodal models"
-
-
+            logger.warning("TPU does not support running Multimodal models"
+                           " without setting `--disable_chunked_mm_input`. "
+                           "Forcing --disable_chunked_mm_input.")
             scheduler_config.disable_chunked_mm_input = True
 
         kv_transfer_config = vllm_config.kv_transfer_config
@@ -127,7 +127,7 @@ class CompilationManager:
 
         self._run_compilation(
             "input_embeddings_merger",
-            self.runner.
+            self.runner.embed_input_ids_fn,
             self.runner.state,
             dummy_input_ids,
             dummy_multimodal_embeddings,
@@ -136,7 +136,7 @@ class CompilationManager:
 
         self._run_compilation(
             "input_embeddings_merger_text_only",
-            self.runner.
+            self.runner.embed_input_ids_fn,
             self.runner.state,
             dummy_input_ids,
             None,
@@ -495,35 +495,37 @@ class CompilationManager:
         logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16,
                                            logits_sharding)
         for do_sampling in (True, False):
-            [29 removed lines truncated in the diff view]
+            for logprobs in (True, False):
+                if do_sampling:
+                    temperature = np.full((num_reqs, ),
+                                          0.7,
+                                          dtype=np.float32)
+                    top_k = np.full((num_reqs, ), 20, dtype=np.int32)
+                    top_p = np.full((num_reqs, ), 0.8, dtype=np.float32)
+                    (temperature, top_k, top_p) = device_array(
+                        self.runner.mesh, (temperature, top_k, top_p),
+                        sharding=sampling_metadata_sharding)
+                else:
+                    temperature = None
+                    top_k = None
+                    top_p = None
+
+                sampling_metadata = TPUSupportedSamplingMetadata(
+                    temperature=temperature,
+                    top_k=top_k,
+                    top_p=top_p,
+                    do_sampling=do_sampling,
+                    logprobs=logprobs)
+                self._run_compilation(
+                    f"worker{self.runner.rank} sample",
+                    sample,
+                    self.runner.rng_params_for_sampling,
+                    self.runner.mesh,
+                    logits,
+                    sampling_metadata,
+                    num_reqs=num_reqs,
+                    do_sampling=do_sampling,
+                )
 
         self._sampling_precompiled = True
 
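
The precompilation loop above warms up the sampling path for every (do_sampling, logprobs) combination using dummy inputs. As a rough, standalone illustration of the underlying idea (not the tpu_inference API), running a jitted function once per padded input shape populates JAX's compilation cache so that real requests with those shapes skip XLA compilation:

```python
# Minimal warm-up sketch; the toy sample_fn, vocab size, and padding list are
# assumptions, not tpu_inference code.
import jax
import jax.numpy as jnp


def sample_fn(logits, temperature):
    # Toy stand-in for a sampling kernel: temperature-scaled greedy argmax.
    return jnp.argmax(logits / temperature, axis=-1)


jitted = jax.jit(sample_fn)
for num_reqs in (8, 16, 32):  # assumed request-count paddings
    dummy_logits = jnp.zeros((num_reqs, 32000), jnp.bfloat16)
    dummy_temperature = jnp.full((num_reqs, ), 0.7, jnp.float32)
    # One call per padded shape compiles and caches the executable up front.
    jitted(dummy_logits, dummy_temperature).block_until_ready()
```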
@@ -555,8 +557,16 @@ class CompilationManager:
         logger.info("Compiling gather_logprobs with different input shapes.")
         hsize = self.runner.model_config.get_vocab_size()
         for num_reqs in self.runner.num_reqs_paddings:
-
-
+            logits_sharding = NamedSharding(
+                self.runner.mesh,
+                PartitionSpec(ShardingAxisName.MLP_DATA,
+                              ShardingAxisName.MLP_TENSOR))
+            token_ids_sharding = NamedSharding(
+                self.runner.mesh, PartitionSpec(ShardingAxisName.MLP_DATA, ))
+            logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16,
+                                               logits_sharding)
+            token_ids = self._create_dummy_tensor((num_reqs, ), jnp.int32,
+                                                  token_ids_sharding)
             self._run_compilation(
                 f"worker{self.runner.rank} gather_logprobs",
                 self.runner._compute_and_gather_logprobs,