tpu-inference 0.11.1.dev202511150811__py3-none-any.whl → 0.11.1.dev202511270815__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (49)
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/lora/test_layers.py +0 -6
  3. tests/lora/utils.py +0 -8
  4. tpu_inference/__init__.py +22 -3
  5. tpu_inference/core/disagg_utils.py +6 -8
  6. tpu_inference/distributed/tpu_connector.py +2 -3
  7. tpu_inference/distributed/utils.py +3 -2
  8. tpu_inference/envs.py +1 -1
  9. tpu_inference/executors/ray_distributed_executor.py +27 -11
  10. tpu_inference/kernels/fused_moe/v1/kernel.py +641 -110
  11. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
  12. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +141 -107
  13. tpu_inference/layers/common/attention_interface.py +7 -1
  14. tpu_inference/layers/common/sharding.py +2 -1
  15. tpu_inference/layers/vllm/fused_moe.py +74 -25
  16. tpu_inference/layers/vllm/quantization/common.py +6 -1
  17. tpu_inference/layers/vllm/quantization/mxfp4.py +135 -61
  18. tpu_inference/layers/vllm/quantization/unquantized.py +107 -113
  19. tpu_inference/layers/vllm/sharding.py +2 -2
  20. tpu_inference/lora/torch_punica_tpu.py +1 -2
  21. tpu_inference/models/common/model_loader.py +43 -11
  22. tpu_inference/models/jax/llama3.py +2 -1
  23. tpu_inference/models/jax/llama_eagle3.py +8 -5
  24. tpu_inference/models/jax/llama_guard_4.py +361 -0
  25. tpu_inference/models/jax/qwen2.py +2 -1
  26. tpu_inference/models/jax/qwen2_5_vl.py +163 -48
  27. tpu_inference/models/jax/qwen3.py +2 -1
  28. tpu_inference/models/jax/utils/weight_utils.py +198 -143
  29. tpu_inference/models/vllm/vllm_model_wrapper.py +13 -5
  30. tpu_inference/platforms/tpu_platform.py +15 -2
  31. tpu_inference/runner/compilation_manager.py +58 -33
  32. tpu_inference/runner/kv_cache_manager.py +9 -3
  33. tpu_inference/runner/structured_decoding_manager.py +2 -3
  34. tpu_inference/runner/tpu_runner.py +203 -102
  35. tpu_inference/spec_decode/jax/eagle3.py +19 -2
  36. tpu_inference/tpu_info.py +4 -3
  37. tpu_inference/utils.py +5 -4
  38. tpu_inference/worker/tpu_worker.py +160 -23
  39. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/METADATA +3 -2
  40. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/RECORD +43 -48
  41. tpu_inference/mock/__init__.py +0 -0
  42. tpu_inference/mock/vllm_config_utils.py +0 -28
  43. tpu_inference/mock/vllm_envs.py +0 -1219
  44. tpu_inference/mock/vllm_logger.py +0 -212
  45. tpu_inference/mock/vllm_logging_utils.py +0 -15
  46. tpu_inference/models/jax/phi3.py +0 -376
  47. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/WHEEL +0 -0
  48. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/licenses/LICENSE +0 -0
  49. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/top_level.txt +0 -0

tpu_inference/runner/compilation_manager.py
@@ -1,6 +1,6 @@
  import os
  import time
- from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple

  import jax
  import jax.numpy as jnp
@@ -135,12 +135,6 @@ class CompilationManager:
  ShardingAxisName.ATTN_DATA, )) if dp_size > 1 else None

  # Keep existing pattern for complex array operations
- block_tables = self.runner.block_table_cpu[:self.runner.max_num_reqs]
- block_tables = block_tables.reshape(-1)
- block_tables = device_array(self.runner.mesh,
- block_tables,
- sharding=dp_sharding)
-
  seq_lens = self._create_dummy_tensor((self.runner.max_num_reqs, ),
  jnp.int32, dp_sharding)
  query_start_loc = self._create_dummy_tensor(
@@ -152,26 +146,45 @@ class CompilationManager:
  request_distribution,
  sharding=dp_sharding)

- attention_metadata = AttentionMetadata(
- input_positions=positions,
- block_tables=block_tables,
- seq_lens=seq_lens,
- query_start_loc=query_start_loc,
- request_distribution=request_distribution,
- )
+ attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+ uniform_attention_metadata: AttentionMetadata = None
+ for kv_cache_gid, kv_cache_group in enumerate(
+ self.runner.kv_cache_config.kv_cache_groups):
+ block_tables = self.runner.block_tables_cpu[
+ kv_cache_gid][:self.runner.max_num_reqs]
+ block_tables = block_tables.reshape(-1)
+ block_tables = device_array(self.runner.mesh,
+ block_tables,
+ sharding=dp_sharding)
+
+ attention_metadata_gid = AttentionMetadata(
+ input_positions=positions,
+ block_tables=block_tables,
+ seq_lens=seq_lens,
+ query_start_loc=query_start_loc,
+ request_distribution=request_distribution,
+ )
+ if not self.runner.use_hybrid_kvcache:
+ # all layers share the same attention metadata
+ uniform_attention_metadata = attention_metadata_gid
+ else:
+ for layer_name in kv_cache_group.layer_names:
+ attention_metadata_per_layer[
+ layer_name] = attention_metadata_gid

  def model_fn_wrapper(
  state,
  kv_caches,
  input_ids,
  attention_metadata,
+ positions,
  inputs_embeds,
  layer_name_to_kvcache_index,
  lora_metadata,
  ):
  kv_caches, hidden_states, _ = self.runner.model_fn(
  state, kv_caches, input_ids, attention_metadata, inputs_embeds,
- layer_name_to_kvcache_index, lora_metadata)
+ positions, layer_name_to_kvcache_index, lora_metadata)
  self.runner.kv_caches = kv_caches
  return hidden_states

@@ -179,6 +192,10 @@ class CompilationManager:
  self.runner.lora_config, np.array([num_tokens],
  dtype=np.int32)):
  lora_metadata = self.runner.lora_utils.extract_lora_metadata()
+ if self.runner.use_hybrid_kvcache:
+ attention_metadata = attention_metadata_per_layer
+ else:
+ attention_metadata = uniform_attention_metadata
  self._run_compilation(
  name,
  model_fn_wrapper,
@@ -186,6 +203,7 @@ class CompilationManager:
  self.runner.kv_caches,
  input_ids,
  attention_metadata,
+ positions,
  inputs_embeds,
  tuple(self.runner.layer_name_to_kvcache_index.items()),
  lora_metadata,
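
The hunks above replace the single dummy AttentionMetadata with per-KV-cache-group metadata: each group gets its own block tables, every layer in a group shares that group's metadata when hybrid KV caches are in use, and everything collapses to one uniform object otherwise. A minimal sketch of that selection logic, using stand-in types rather than the package's own classes:

from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class AttentionMetadata:
    block_tables: List[int]


@dataclass
class KVCacheGroup:
    layer_names: List[str]
    block_tables: List[int]


def build_attention_metadata(groups: List[KVCacheGroup],
                             use_hybrid_kvcache: bool):
    per_layer: Dict[str, AttentionMetadata] = {}
    uniform: Optional[AttentionMetadata] = None
    for group in groups:
        metadata = AttentionMetadata(block_tables=group.block_tables)
        if not use_hybrid_kvcache:
            # Single KV-cache group: every layer shares one metadata object.
            uniform = metadata
        else:
            # Hybrid caches: each group (e.g. full vs. sliding-window layers)
            # carries its own block tables, keyed by layer name.
            for name in group.layer_names:
                per_layer[name] = metadata
    return per_layer if use_hybrid_kvcache else uniform


groups = [
    KVCacheGroup(layer_names=["layers.0", "layers.2"], block_tables=[0, 1]),
    KVCacheGroup(layer_names=["layers.1", "layers.3"], block_tables=[2, 3]),
]
print(build_attention_metadata(groups, use_hybrid_kvcache=True))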
@@ -332,13 +350,15 @@ class CompilationManager:
  index_paddings = self.runner.num_reqs_paddings
  dp_sharding = NamedSharding(self.runner.mesh,
  PartitionSpec(ShardingAxisName.ATTN_DATA))
+ hidden_states_sharding = NamedSharding(
+ self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None))
  dp_size = self.runner.vllm_config.sharding_config.total_dp_size
  self._precompile_select_from_array_helper(
  name="select all logits",
  source_paddings=self.runner.num_tokens_paddings,
  indices_paddings=index_paddings,
  hidden_dim=hsize,
- input_sharding=dp_sharding,
+ input_sharding=hidden_states_sharding,
  indices_sharding=dp_sharding if dp_size > 1 else None,
  )

@@ -528,7 +548,9 @@ class CompilationManager:
  def _precompile_eagle3_helpers(self) -> None:
  logger.info(
  "Compiling eagle3 jitted helpers with different input shapes.")
- hidden_size = self.runner.model_config.get_hidden_size()
+ target_hidden_size = self.runner.model_config.get_hidden_size()
+ draft_hidden_size = self.runner.speculative_config.draft_model_config.get_hidden_size(
+ )
  dtype = self.runner.model_config.dtype

  num_kv_cache_groups = len(self.runner.kv_cache_config.kv_cache_groups)
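
The hunk above splits the old hidden_size into a target width and a draft width for the eagle3 helpers. A rough shape sketch (the sizes below are placeholders, not values taken from the package): auxiliary hidden states captured from the target model keep the target width, while tensors flowing into and out of the drafter use the draft width.

import jax.numpy as jnp

target_hidden_size = 4096  # placeholder: target model width
draft_hidden_size = 2048   # placeholder: draft model width
num_tokens = 128

# Aux hidden states come from three target-model layers, so they keep the
# target width.
aux_hidden_states = [
    jnp.zeros((num_tokens, target_hidden_size), jnp.bfloat16)
    for _ in range(3)
]

# Tensors fed to or produced by the draft model use the draft width.
draft_hidden_states = jnp.zeros((num_tokens, draft_hidden_size), jnp.bfloat16)
print(aux_hidden_states[0].shape, draft_hidden_states.shape)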
@@ -575,7 +597,7 @@

  for num_logits in self.runner.num_logits_paddings:
  hidden_states = self._create_dummy_tensor(
- (num_logits, hidden_size), jnp.bfloat16)
+ (num_logits, draft_hidden_size), jnp.bfloat16)
  self._run_compilation(
  "eagle3_get_draft_token_ids",
  self.runner.drafter._get_draft_token_ids,
@@ -586,8 +608,8 @@
  input_ids_loop = self._create_dummy_tensor(
  (self.runner.max_num_reqs, ), jnp.int32,
  NamedSharding(self.runner.mesh, PartitionSpec()))
- target_hidden_state_loop = self._create_dummy_tensor(
- (self.runner.max_num_reqs, hidden_size), dtype,
+ draft_hidden_state_loop = self._create_dummy_tensor(
+ (self.runner.max_num_reqs, draft_hidden_size), dtype,
  NamedSharding(self.runner.mesh, PartitionSpec(None, None)))
  next_token_ids = self._create_dummy_tensor(
  (self.runner.max_num_reqs, ), jnp.int32)
@@ -595,9 +617,12 @@
  (self.runner.max_num_reqs, ), jnp.int32)
  for num_tokens in self.runner.num_tokens_paddings:
  aux_hidden_states = [
- self._create_dummy_tensor((num_tokens, hidden_size), dtype),
- self._create_dummy_tensor((num_tokens, hidden_size), dtype),
- self._create_dummy_tensor((num_tokens, hidden_size), dtype),
+ self._create_dummy_tensor((num_tokens, target_hidden_size),
+ dtype),
+ self._create_dummy_tensor((num_tokens, target_hidden_size),
+ dtype),
+ self._create_dummy_tensor((num_tokens, target_hidden_size),
+ dtype),
  ]

  positions = self._create_dummy_tensor((num_tokens, ), jnp.int32)
@@ -628,15 +653,15 @@
  input_ids = self._create_dummy_tensor((num_tokens, ), jnp.int32)
  aux_hidden_states = [
  self._create_dummy_tensor(
- (num_tokens, hidden_size), jnp.bfloat16,
+ (num_tokens, target_hidden_size), jnp.bfloat16,
  NamedSharding(self.runner.mesh, PartitionSpec(None,
  None))),
  self._create_dummy_tensor(
- (num_tokens, hidden_size), jnp.bfloat16,
+ (num_tokens, target_hidden_size), jnp.bfloat16,
  NamedSharding(self.runner.mesh, PartitionSpec(None,
  None))),
  self._create_dummy_tensor(
- (num_tokens, hidden_size), jnp.bfloat16,
+ (num_tokens, target_hidden_size), jnp.bfloat16,
  NamedSharding(self.runner.mesh, PartitionSpec(None,
  None))),
  ]
@@ -668,17 +693,17 @@
  state,
  kv_caches,
  input_ids,
- target_hidden_states,
+ draft_hidden_states,
  attention_metadata,
  ):
  kv_caches, hidden_states, _ = self.runner.drafter.model_fn(
- state, kv_caches, input_ids, target_hidden_states,
+ state, kv_caches, input_ids, draft_hidden_states,
  attention_metadata)
  self.runner.kv_caches = kv_caches
  return hidden_states

- target_hidden_states = self._create_dummy_tensor(
- (num_tokens, hidden_size), dtype,
+ draft_hidden_states = self._create_dummy_tensor(
+ (num_tokens, draft_hidden_size), dtype,
  NamedSharding(self.runner.mesh, PartitionSpec(None, "model")))
  input_ids = self._create_dummy_tensor(
  (num_tokens, ), jnp.int32,
@@ -689,7 +714,7 @@
  self.runner.drafter.state,
  self.runner.kv_caches,
  input_ids,
- target_hidden_states,
+ draft_hidden_states,
  attention_metadata,
  num_tokens=num_tokens,
  )
@@ -721,13 +746,13 @@
  self.runner.drafter.state,
  self.runner.kv_caches,
  input_ids_loop,
- target_hidden_state_loop,
+ draft_hidden_state_loop,
  attention_metadata,
  num_tokens=num_tokens,
  )

  hidden_states = self._create_dummy_tensor(
- (num_tokens, hidden_size), jnp.bfloat16,
+ (num_tokens, draft_hidden_size), jnp.bfloat16,
  NamedSharding(self.runner.mesh, PartitionSpec(None, None)))

  self._run_compilation(

tpu_inference/runner/kv_cache_manager.py
@@ -1,15 +1,16 @@
  import functools
- import math
  from typing import TYPE_CHECKING, Dict, List

  import jax
  import jax.numpy as jnp
+ import numpy as np
  import vllm.envs as envs
  from jax.sharding import NamedSharding, PartitionSpec
  from torchax.ops.mappings import t2j_dtype
- from vllm.attention import Attention
  from vllm.attention.backends.abstract import AttentionType
+ from vllm.attention.layer import Attention
  from vllm.config import get_layers_from_vllm_config
+ from vllm.utils.math_utils import cdiv
  from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
  KVCacheSpec, MLAAttentionSpec,
  SlidingWindowSpec)
@@ -175,6 +176,11 @@ class KVCacheManager:
  )
  self.runner.input_batch = new_input_batch
  self.runner.persistent_batch_manager.input_batch = new_input_batch
+ self.runner.block_tables_cpu = [
+ np.zeros((self.runner.max_num_reqs,
+ cdiv(self.runner.max_model_len, block_size)),
+ dtype=np.int32) for block_size in block_sizes
+ ]

  def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
  self.maybe_reinitialize_input_batch(kv_cache_config)
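
The added block_tables_cpu keeps one CPU-side block table per KV-cache group, each wide enough to hold a request at max_model_len with that group's block size; cdiv is ceiling division. A self-contained sketch with assumed example values:

import numpy as np


def cdiv(a: int, b: int) -> int:
    # Ceiling division, matching vllm.utils.math_utils.cdiv.
    return -(-a // b)


max_num_reqs = 256       # assumed example value
max_model_len = 8192     # assumed example value
block_sizes = [16, 128]  # assumed: one block size per KV-cache group

block_tables_cpu = [
    np.zeros((max_num_reqs, cdiv(max_model_len, block_size)), dtype=np.int32)
    for block_size in block_sizes
]
print([t.shape for t in block_tables_cpu])  # [(256, 512), (256, 64)]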
@@ -190,7 +196,7 @@
  num_blocks = kv_cache_tensor.size // page_size_bytes
  dp_size = self.runner.vllm_config.sharding_config.total_dp_size
  # num_blocks must be a multiple of dp_size
- num_blocks = math.ceil(num_blocks / dp_size) * dp_size
+ num_blocks = (num_blocks // dp_size) * dp_size
  # NOTE: we'll multiply the num_kv_heads by 2 in the function
  kv_cache = create_kv_caches(
  num_blocks=num_blocks,
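
A worked example of the rounding change: num_blocks is derived from the already-sized KV-cache tensor, so rounding up to a multiple of dp_size could presumably ask for more blocks than the tensor holds; rounding down always stays within it.

import math

num_blocks = 1001  # example: blocks that fit in the allocated tensor
dp_size = 4

rounded_up = math.ceil(num_blocks / dp_size) * dp_size  # 1004, more than fits
rounded_down = (num_blocks // dp_size) * dp_size        # 1000, always fits
print(rounded_up, rounded_down)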

tpu_inference/runner/structured_decoding_manager.py
@@ -61,11 +61,10 @@ class StructuredDecodingManager:
  self.runner.require_structured_out_cpu.fill(0)

  sorted_struct_requests = sorted(
- grammar_output.structured_output_request_ids.items(),
- key=lambda item: item[1])
+ grammar_output.structured_output_request_ids)

  cumulative_mask_idx = 0
- for req_id, _ in sorted_struct_requests:
+ for req_id in sorted_struct_requests:
  if req_id not in self.runner.input_batch.req_id_to_index:
  continue
  batch_index = self.runner.input_batch.req_id_to_index[req_id]
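
For reference, the iteration change in toy form (illustrative data only, not the real vLLM grammar-output structures): the old code sorted (req_id, position) pairs by position and discarded the position; the new code sorts the container directly and consumes request ids.

structured_output_request_ids = {"req-b": 1, "req-a": 0}

# Old pattern: sort by stored position, unpack (req_id, position) pairs.
for req_id, _ in sorted(structured_output_request_ids.items(),
                        key=lambda item: item[1]):
    print("old:", req_id)

# New pattern: sort the container itself and use the request ids directly.
for req_id in sorted(structured_output_request_ids):
    print("new:", req_id)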