tpu-inference 0.11.1.dev202511150811__py3-none-any.whl → 0.11.1.dev202511270815__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/lora/test_layers.py +0 -6
- tests/lora/utils.py +0 -8
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +2 -3
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +1 -1
- tpu_inference/executors/ray_distributed_executor.py +27 -11
- tpu_inference/kernels/fused_moe/v1/kernel.py +641 -110
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +141 -107
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +2 -1
- tpu_inference/layers/vllm/fused_moe.py +74 -25
- tpu_inference/layers/vllm/quantization/common.py +6 -1
- tpu_inference/layers/vllm/quantization/mxfp4.py +135 -61
- tpu_inference/layers/vllm/quantization/unquantized.py +107 -113
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +43 -11
- tpu_inference/models/jax/llama3.py +2 -1
- tpu_inference/models/jax/llama_eagle3.py +8 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +2 -1
- tpu_inference/models/jax/qwen2_5_vl.py +163 -48
- tpu_inference/models/jax/qwen3.py +2 -1
- tpu_inference/models/jax/utils/weight_utils.py +198 -143
- tpu_inference/models/vllm/vllm_model_wrapper.py +13 -5
- tpu_inference/platforms/tpu_platform.py +15 -2
- tpu_inference/runner/compilation_manager.py +58 -33
- tpu_inference/runner/kv_cache_manager.py +9 -3
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +203 -102
- tpu_inference/spec_decode/jax/eagle3.py +19 -2
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +5 -4
- tpu_inference/worker/tpu_worker.py +160 -23
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/METADATA +3 -2
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/RECORD +43 -48
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/top_level.txt +0 -0
tpu_inference/runner/compilation_manager.py

@@ -1,6 +1,6 @@
 import os
 import time
-from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple

 import jax
 import jax.numpy as jnp
@@ -135,12 +135,6 @@ class CompilationManager:
             ShardingAxisName.ATTN_DATA, )) if dp_size > 1 else None

         # Keep existing pattern for complex array operations
-        block_tables = self.runner.block_table_cpu[:self.runner.max_num_reqs]
-        block_tables = block_tables.reshape(-1)
-        block_tables = device_array(self.runner.mesh,
-                                    block_tables,
-                                    sharding=dp_sharding)
-
         seq_lens = self._create_dummy_tensor((self.runner.max_num_reqs, ),
                                              jnp.int32, dp_sharding)
         query_start_loc = self._create_dummy_tensor(
@@ -152,26 +146,45 @@ class CompilationManager:
             request_distribution,
             sharding=dp_sharding)

-
-
-
-
-
-
-
+        attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+        uniform_attention_metadata: AttentionMetadata = None
+        for kv_cache_gid, kv_cache_group in enumerate(
+                self.runner.kv_cache_config.kv_cache_groups):
+            block_tables = self.runner.block_tables_cpu[
+                kv_cache_gid][:self.runner.max_num_reqs]
+            block_tables = block_tables.reshape(-1)
+            block_tables = device_array(self.runner.mesh,
+                                        block_tables,
+                                        sharding=dp_sharding)
+
+            attention_metadata_gid = AttentionMetadata(
+                input_positions=positions,
+                block_tables=block_tables,
+                seq_lens=seq_lens,
+                query_start_loc=query_start_loc,
+                request_distribution=request_distribution,
+            )
+            if not self.runner.use_hybrid_kvcache:
+                # all layers share the same attention metadata
+                uniform_attention_metadata = attention_metadata_gid
+            else:
+                for layer_name in kv_cache_group.layer_names:
+                    attention_metadata_per_layer[
+                        layer_name] = attention_metadata_gid

         def model_fn_wrapper(
             state,
             kv_caches,
             input_ids,
             attention_metadata,
+            positions,
             inputs_embeds,
             layer_name_to_kvcache_index,
             lora_metadata,
         ):
             kv_caches, hidden_states, _ = self.runner.model_fn(
                 state, kv_caches, input_ids, attention_metadata, inputs_embeds,
-                layer_name_to_kvcache_index, lora_metadata)
+                positions, layer_name_to_kvcache_index, lora_metadata)
             self.runner.kv_caches = kv_caches
             return hidden_states

@@ -179,6 +192,10 @@ class CompilationManager:
                 self.runner.lora_config, np.array([num_tokens],
                                                   dtype=np.int32)):
             lora_metadata = self.runner.lora_utils.extract_lora_metadata()
+            if self.runner.use_hybrid_kvcache:
+                attention_metadata = attention_metadata_per_layer
+            else:
+                attention_metadata = uniform_attention_metadata
             self._run_compilation(
                 name,
                 model_fn_wrapper,
@@ -186,6 +203,7 @@ class CompilationManager:
                 self.runner.kv_caches,
                 input_ids,
                 attention_metadata,
+                positions,
                 inputs_embeds,
                 tuple(self.runner.layer_name_to_kvcache_index.items()),
                 lora_metadata,
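Note on the hunks above: with more than one KV-cache group (use_hybrid_kvcache), the precompile path now builds one AttentionMetadata per group and maps it onto that group's layers; otherwise a single metadata object is shared by all layers. The sketch below is a simplified, hypothetical rendering of that branch; `runner`, `make_metadata`, and the group/layer attributes mirror names in the diff, but the scaffolding is illustrative and not the actual tpu-inference API.

from typing import Dict


def build_attention_metadata(runner, make_metadata):
    # One metadata object per KV-cache group when hybrid caching is on,
    # otherwise a single shared object (mirrors the logic added above).
    attention_metadata_per_layer: Dict[str, object] = {}
    uniform_attention_metadata = None
    for gid, group in enumerate(runner.kv_cache_config.kv_cache_groups):
        # Each group has its own CPU block table, hence its own metadata.
        metadata = make_metadata(runner.block_tables_cpu[gid])
        if not runner.use_hybrid_kvcache:
            uniform_attention_metadata = metadata
        else:
            for layer_name in group.layer_names:
                attention_metadata_per_layer[layer_name] = metadata
    if runner.use_hybrid_kvcache:
        return attention_metadata_per_layer
    return uniform_attention_metadata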
@@ -332,13 +350,15 @@ class CompilationManager:
         index_paddings = self.runner.num_reqs_paddings
         dp_sharding = NamedSharding(self.runner.mesh,
                                     PartitionSpec(ShardingAxisName.ATTN_DATA))
+        hidden_states_sharding = NamedSharding(
+            self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None))
         dp_size = self.runner.vllm_config.sharding_config.total_dp_size
         self._precompile_select_from_array_helper(
             name="select all logits",
             source_paddings=self.runner.num_tokens_paddings,
             indices_paddings=index_paddings,
             hidden_dim=hsize,
-            input_sharding=
+            input_sharding=hidden_states_sharding,
             indices_sharding=dp_sharding if dp_size > 1 else None,
         )

@@ -528,7 +548,9 @@ class CompilationManager:
     def _precompile_eagle3_helpers(self) -> None:
         logger.info(
             "Compiling eagle3 jitted helpers with different input shapes.")
-
+        target_hidden_size = self.runner.model_config.get_hidden_size()
+        draft_hidden_size = self.runner.speculative_config.draft_model_config.get_hidden_size(
+        )
         dtype = self.runner.model_config.dtype

         num_kv_cache_groups = len(self.runner.kv_cache_config.kv_cache_groups)
@@ -575,7 +597,7 @@ class CompilationManager:

         for num_logits in self.runner.num_logits_paddings:
             hidden_states = self._create_dummy_tensor(
-                (num_logits,
+                (num_logits, draft_hidden_size), jnp.bfloat16)
             self._run_compilation(
                 "eagle3_get_draft_token_ids",
                 self.runner.drafter._get_draft_token_ids,
@@ -586,8 +608,8 @@ class CompilationManager:
         input_ids_loop = self._create_dummy_tensor(
             (self.runner.max_num_reqs, ), jnp.int32,
             NamedSharding(self.runner.mesh, PartitionSpec()))
-
-            (self.runner.max_num_reqs,
+        draft_hidden_state_loop = self._create_dummy_tensor(
+            (self.runner.max_num_reqs, draft_hidden_size), dtype,
             NamedSharding(self.runner.mesh, PartitionSpec(None, None)))
         next_token_ids = self._create_dummy_tensor(
             (self.runner.max_num_reqs, ), jnp.int32)
@@ -595,9 +617,12 @@ class CompilationManager:
             (self.runner.max_num_reqs, ), jnp.int32)
         for num_tokens in self.runner.num_tokens_paddings:
             aux_hidden_states = [
-                self._create_dummy_tensor((num_tokens,
-
-                self._create_dummy_tensor((num_tokens,
+                self._create_dummy_tensor((num_tokens, target_hidden_size),
+                                          dtype),
+                self._create_dummy_tensor((num_tokens, target_hidden_size),
+                                          dtype),
+                self._create_dummy_tensor((num_tokens, target_hidden_size),
+                                          dtype),
             ]

             positions = self._create_dummy_tensor((num_tokens, ), jnp.int32)
@@ -628,15 +653,15 @@ class CompilationManager:
             input_ids = self._create_dummy_tensor((num_tokens, ), jnp.int32)
             aux_hidden_states = [
                 self._create_dummy_tensor(
-                    (num_tokens,
+                    (num_tokens, target_hidden_size), jnp.bfloat16,
                     NamedSharding(self.runner.mesh, PartitionSpec(None,
                                                                   None))),
                 self._create_dummy_tensor(
-                    (num_tokens,
+                    (num_tokens, target_hidden_size), jnp.bfloat16,
                     NamedSharding(self.runner.mesh, PartitionSpec(None,
                                                                   None))),
                 self._create_dummy_tensor(
-                    (num_tokens,
+                    (num_tokens, target_hidden_size), jnp.bfloat16,
                     NamedSharding(self.runner.mesh, PartitionSpec(None,
                                                                   None))),
             ]
@@ -668,17 +693,17 @@ class CompilationManager:
             state,
             kv_caches,
             input_ids,
-
+            draft_hidden_states,
             attention_metadata,
         ):
             kv_caches, hidden_states, _ = self.runner.drafter.model_fn(
-                state, kv_caches, input_ids,
+                state, kv_caches, input_ids, draft_hidden_states,
                 attention_metadata)
             self.runner.kv_caches = kv_caches
             return hidden_states

-
-            (num_tokens,
+        draft_hidden_states = self._create_dummy_tensor(
+            (num_tokens, draft_hidden_size), dtype,
             NamedSharding(self.runner.mesh, PartitionSpec(None, "model")))
         input_ids = self._create_dummy_tensor(
             (num_tokens, ), jnp.int32,
@@ -689,7 +714,7 @@ class CompilationManager:
             self.runner.drafter.state,
             self.runner.kv_caches,
             input_ids,
-
+            draft_hidden_states,
             attention_metadata,
             num_tokens=num_tokens,
         )
@@ -721,13 +746,13 @@ class CompilationManager:
             self.runner.drafter.state,
             self.runner.kv_caches,
             input_ids_loop,
-
+            draft_hidden_state_loop,
             attention_metadata,
             num_tokens=num_tokens,
         )

         hidden_states = self._create_dummy_tensor(
-            (num_tokens,
+            (num_tokens, draft_hidden_size), jnp.bfloat16,
             NamedSharding(self.runner.mesh, PartitionSpec(None, None)))

         self._run_compilation(
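The eagle3 precompilation hunks above distinguish two widths: the target model's hidden size (used for the aux hidden states the target model emits) and the draft model's hidden size (used for the drafter's own hidden states). Below is a minimal, hypothetical sketch of how dummy tensors are shaped under that split; the jnp calls are standard JAX, everything else is illustrative rather than the runner's real helper.

import jax.numpy as jnp


def make_eagle3_dummies(num_tokens, target_hidden_size, draft_hidden_size,
                        dtype=jnp.bfloat16):
    # Aux hidden states come from the target model, so they use its width.
    aux_hidden_states = [
        jnp.zeros((num_tokens, target_hidden_size), dtype) for _ in range(3)
    ]
    # The drafter consumes and produces tensors at the draft model's width.
    draft_hidden_states = jnp.zeros((num_tokens, draft_hidden_size), dtype)
    return aux_hidden_states, draft_hidden_states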
tpu_inference/runner/kv_cache_manager.py

@@ -1,15 +1,16 @@
 import functools
-import math
 from typing import TYPE_CHECKING, Dict, List

 import jax
 import jax.numpy as jnp
+import numpy as np
 import vllm.envs as envs
 from jax.sharding import NamedSharding, PartitionSpec
 from torchax.ops.mappings import t2j_dtype
-from vllm.attention import Attention
 from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.layer import Attention
 from vllm.config import get_layers_from_vllm_config
+from vllm.utils.math_utils import cdiv
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec, MLAAttentionSpec,
                                         SlidingWindowSpec)
@@ -175,6 +176,11 @@ class KVCacheManager:
         )
         self.runner.input_batch = new_input_batch
         self.runner.persistent_batch_manager.input_batch = new_input_batch
+        self.runner.block_tables_cpu = [
+            np.zeros((self.runner.max_num_reqs,
+                      cdiv(self.runner.max_model_len, block_size)),
+                     dtype=np.int32) for block_size in block_sizes
+        ]

     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         self.maybe_reinitialize_input_batch(kv_cache_config)
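The added block_tables_cpu initialization allocates one table per KV-cache group: rows are request slots and columns are the maximum number of blocks a max_model_len request can occupy at that group's block size (ceiling division). A self-contained sketch with made-up sizes; cdiv is inlined here instead of importing it from vllm:

import numpy as np


def cdiv(a: int, b: int) -> int:
    return -(-a // b)  # ceiling division


def make_block_tables_cpu(max_num_reqs, max_model_len, block_sizes):
    # One int32 table per KV-cache group, zero-initialized.
    return [
        np.zeros((max_num_reqs, cdiv(max_model_len, block_size)),
                 dtype=np.int32) for block_size in block_sizes
    ]


tables = make_block_tables_cpu(max_num_reqs=8, max_model_len=1024,
                               block_sizes=[16, 128])
assert [t.shape for t in tables] == [(8, 64), (8, 8)]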
@@ -190,7 +196,7 @@ class KVCacheManager:
             num_blocks = kv_cache_tensor.size // page_size_bytes
             dp_size = self.runner.vllm_config.sharding_config.total_dp_size
             # num_blocks must be a multiple of dp_size
-            num_blocks =
+            num_blocks = (num_blocks // dp_size) * dp_size
             # NOTE: we'll multiply the num_kv_heads by 2 in the function
             kv_cache = create_kv_caches(
                 num_blocks=num_blocks,
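The new num_blocks line above rounds the block count down to a multiple of the data-parallel size so the KV cache splits evenly across DP shards. A tiny worked example with made-up numbers:

num_blocks = 1001
dp_size = 4
num_blocks = (num_blocks // dp_size) * dp_size
assert num_blocks == 1000  # largest multiple of dp_size not exceeding 1001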
tpu_inference/runner/structured_decoding_manager.py

@@ -61,11 +61,10 @@ class StructuredDecodingManager:
         self.runner.require_structured_out_cpu.fill(0)

         sorted_struct_requests = sorted(
-            grammar_output.structured_output_request_ids
-            key=lambda item: item[1])
+            grammar_output.structured_output_request_ids)

         cumulative_mask_idx = 0
-        for req_id
+        for req_id in sorted_struct_requests:
             if req_id not in self.runner.input_batch.req_id_to_index:
                 continue
             batch_index = self.runner.input_batch.req_id_to_index[req_id]
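For context, the change above drops the item-based sort key and iterates the request ids directly, skipping any id that is no longer in the batch. A simplified, self-contained sketch of that flow (the ids and index map are made up):

structured_output_request_ids = ["req-3", "req-1", "req-7"]
req_id_to_index = {"req-1": 0, "req-3": 2}  # "req-7" already left the batch

cumulative_mask_idx = 0
for req_id in sorted(structured_output_request_ids):
    if req_id not in req_id_to_index:
        continue
    batch_index = req_id_to_index[req_id]
    # ...use batch_index / cumulative_mask_idx to place this request's bitmask...
    cumulative_mask_idx += 1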