tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic.

Files changed (76)
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/kernels/mla_v1_test.py +129 -41
  3. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  4. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
  5. tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
  6. tests/lora/test_layers.py +4 -7
  7. tests/lora/test_lora_perf.py +53 -0
  8. tests/lora/utils.py +0 -8
  9. tests/test_envs.py +110 -12
  10. tests/test_quantization.py +3 -0
  11. tests/test_utils.py +1 -2
  12. tpu_inference/__init__.py +22 -3
  13. tpu_inference/core/disagg_utils.py +6 -8
  14. tpu_inference/distributed/tpu_connector.py +3 -4
  15. tpu_inference/distributed/utils.py +3 -2
  16. tpu_inference/envs.py +93 -9
  17. tpu_inference/executors/ray_distributed_executor.py +9 -2
  18. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  19. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  20. tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
  21. tpu_inference/kernels/mla/v1/kernel.py +98 -120
  22. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  23. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  24. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  25. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +140 -67
  26. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +204 -120
  27. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
  28. tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
  29. tpu_inference/layers/common/attention_interface.py +7 -1
  30. tpu_inference/layers/common/sharding.py +11 -7
  31. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
  32. tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
  33. tpu_inference/layers/vllm/fused_moe.py +170 -208
  34. tpu_inference/layers/vllm/linear_common.py +43 -21
  35. tpu_inference/layers/vllm/quantization/common.py +11 -6
  36. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
  37. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
  38. tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
  39. tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
  40. tpu_inference/layers/vllm/sharding.py +2 -2
  41. tpu_inference/lora/torch_punica_tpu.py +1 -2
  42. tpu_inference/models/common/model_loader.py +84 -28
  43. tpu_inference/models/jax/deepseek_v3.py +185 -64
  44. tpu_inference/models/jax/gpt_oss.py +3 -3
  45. tpu_inference/models/jax/llama3.py +2 -1
  46. tpu_inference/models/jax/llama_eagle3.py +8 -5
  47. tpu_inference/models/jax/llama_guard_4.py +361 -0
  48. tpu_inference/models/jax/qwen2.py +2 -1
  49. tpu_inference/models/jax/qwen2_5_vl.py +163 -48
  50. tpu_inference/models/jax/qwen3.py +2 -1
  51. tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
  52. tpu_inference/models/jax/utils/weight_utils.py +205 -144
  53. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -8
  54. tpu_inference/platforms/tpu_platform.py +34 -50
  55. tpu_inference/runner/compilation_manager.py +144 -60
  56. tpu_inference/runner/kv_cache.py +40 -20
  57. tpu_inference/runner/kv_cache_manager.py +48 -33
  58. tpu_inference/runner/persistent_batch_manager.py +40 -2
  59. tpu_inference/runner/structured_decoding_manager.py +2 -3
  60. tpu_inference/runner/tpu_runner.py +280 -149
  61. tpu_inference/runner/utils.py +2 -2
  62. tpu_inference/spec_decode/jax/eagle3.py +71 -21
  63. tpu_inference/tpu_info.py +4 -3
  64. tpu_inference/utils.py +46 -18
  65. tpu_inference/worker/tpu_worker.py +197 -63
  66. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +9 -10
  67. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +70 -74
  68. tpu_inference/mock/__init__.py +0 -0
  69. tpu_inference/mock/vllm_config_utils.py +0 -28
  70. tpu_inference/mock/vllm_envs.py +0 -1219
  71. tpu_inference/mock/vllm_logger.py +0 -212
  72. tpu_inference/mock/vllm_logging_utils.py +0 -15
  73. tpu_inference/models/jax/phi3.py +0 -376
  74. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
  75. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
  76. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
tpu_inference/runner/kv_cache_manager.py

@@ -1,15 +1,16 @@
 import functools
-import math
 from typing import TYPE_CHECKING, Dict, List

 import jax
 import jax.numpy as jnp
+import numpy as np
 import vllm.envs as envs
 from jax.sharding import NamedSharding, PartitionSpec
 from torchax.ops.mappings import t2j_dtype
-from vllm.attention import Attention
 from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.layer import Attention
 from vllm.config import get_layers_from_vllm_config
+from vllm.utils.math_utils import cdiv
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec, MLAAttentionSpec,
                                         SlidingWindowSpec)
@@ -38,20 +39,30 @@ class KVCacheManager:
         # means this layer will perform attention using the keys and values
         # from the KV cache of `shared_kv_cache_layers[layer_name]`.
         self.shared_kv_cache_layers: dict[str, str] = {}
+        self.use_mla = self.runner.model_config.use_mla

     def get_kv_cache_spec(self):
         # TODO(xiang): this hack tricks engine core to init successfully
         block_size = self.runner.cache_config.block_size
-        use_mla = self.runner.model_config.use_mla
         kv_cache_spec: dict[str, KVCacheSpec] = {}

         # If use pure jax (MODEL_IMPL_TYPE=flax_nnx), we don't register
         # attention into compilation config.
         # Use FullAttentionSpec for each layer
         # TODO(pooyam): Is it possible to merge the logic for vllm and non-vllm models?
+        model_config = self.runner.model_config
+        if self.use_mla:
+            # Individually pad the RopE and latents
+            qk_rope_head_dim = getattr(model_config.hf_text_config,
+                                       "qk_rope_head_dim", 0)
+            padded_kv_lora_rank = common_utils.align_to(
+                model_config.hf_text_config.kv_lora_rank, 128)
+            padded_qk_rope_head_dim = common_utils.align_to(
+                qk_rope_head_dim, 128)
+            mla_head_size = padded_kv_lora_rank + padded_qk_rope_head_dim
+
         if len(self.runner.vllm_config.compilation_config.
                static_forward_context) == 0:
-            model_config = self.runner.model_config
             parallel_config = self.runner.parallel_config
             # Pad num_kv_heads to multiple of TP size.
             num_kv_heads = common_utils.get_padded_num_heads(
@@ -60,11 +71,11 @@ class KVCacheManager:
             head_size = common_utils.get_padded_head_dim(
                 model_config.get_head_size())
             for i in range(model_config.get_num_layers(parallel_config)):
-                if use_mla:
+                if self.use_mla:
                     kv_cache_spec[f"layer.{i}"] = MLAAttentionSpec(
                         block_size=block_size,
-                        num_kv_heads=num_kv_heads,
-                        head_size=head_size,
+                        num_kv_heads=1,
+                        head_size=mla_head_size,
                         dtype=self.runner.kv_cache_dtype,
                         cache_dtype_str=self.runner.vllm_config.cache_config.
                         cache_dtype)
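
For orientation, the MLA spec above stores a single "head" per token that packs the compressed KV latent together with the RoPE key, each padded up to a multiple of 128. A rough sketch of the arithmetic, assuming `align_to` rounds up to the given multiple and using DeepSeek-V3-like dimensions (kv_lora_rank=512, qk_rope_head_dim=64) purely as an illustration:

    def align_to(x: int, multiple: int) -> int:
        # Assumed behavior: round x up to the next multiple of `multiple`.
        return ((x + multiple - 1) // multiple) * multiple

    padded_kv_lora_rank = align_to(512, 128)       # 512 (already aligned)
    padded_qk_rope_head_dim = align_to(64, 128)    # 128
    mla_head_size = padded_kv_lora_rank + padded_qk_rope_head_dim  # 640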
@@ -82,14 +93,13 @@ class KVCacheManager:
                 self.runner.mesh.shape["model"])
             head_size = common_utils.get_padded_head_dim(
                 hf_config.hidden_size // hf_config.num_attention_heads)
-
             # Eagle3 has only 1 layer
             for i in range(1):
-                if use_mla:
-                    kv_cache_spec[f"layer.{i}"] = MLAAttentionSpec(
+                if self.use_mla:
+                    kv_cache_spec[f"draft_layer.{i}"] = MLAAttentionSpec(
                         block_size=block_size,
-                        num_kv_heads=num_kv_heads,
-                        head_size=head_size,
+                        num_kv_heads=1,
+                        head_size=mla_head_size,
                         dtype=self.runner.kv_cache_dtype,
                         cache_dtype_str=self.runner.vllm_config.
                         cache_config.cache_dtype)
@@ -103,6 +113,7 @@ class KVCacheManager:
         # Else propagate attention modules from compilation config.
         layers = get_layers_from_vllm_config(self.runner.vllm_config,
                                              Attention)
+        logger.warning(f"Compilation num_layers = {len(layers.items())}")
         for layer_name, attn_module in layers.items():
             if (kv_tgt_layer :=
                     attn_module.kv_sharing_target_layer_name) is not None:
@@ -126,11 +137,11 @@ class KVCacheManager:
                         attn_module.head_size),
                     dtype=self.runner.kv_cache_dtype,
                     sliding_window=attn_module.sliding_window)
-            elif use_mla:
-                kv_cache_spec[f"layer.{i}"] = MLAAttentionSpec(
+            elif self.use_mla:
+                kv_cache_spec[layer_name] = MLAAttentionSpec(
                     block_size=block_size,
-                    num_kv_heads=attn_module.num_kv_heads,
-                    head_size=attn_module.head_size,
+                    num_kv_heads=1,
+                    head_size=mla_head_size,
                     dtype=self.runner.kv_cache_dtype,
                     cache_dtype_str=self.runner.vllm_config.
                     cache_config.cache_dtype)
@@ -175,6 +186,11 @@ class KVCacheManager:
         )
         self.runner.input_batch = new_input_batch
         self.runner.persistent_batch_manager.input_batch = new_input_batch
+        self.runner.block_tables_cpu = [
+            np.zeros((self.runner.max_num_reqs,
+                      cdiv(self.runner.max_model_len, block_size)),
+                     dtype=np.int32) for block_size in block_sizes
+        ]

     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         self.maybe_reinitialize_input_batch(kv_cache_config)
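
`cdiv` here is vLLM's ceiling-division helper, so each per-request block-table row gets enough slots to cover the full model length even when it is not an exact multiple of the block size. A tiny illustration with made-up numbers:

    def cdiv(a: int, b: int) -> int:
        # Ceiling division: smallest integer >= a / b.
        return -(-a // b)

    max_model_len, block_size = 1000, 16
    slots_per_request = cdiv(max_model_len, block_size)  # 63, since 62 * 16 = 992 < 1000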
@@ -190,16 +206,22 @@ class KVCacheManager:
             num_blocks = kv_cache_tensor.size // page_size_bytes
             dp_size = self.runner.vllm_config.sharding_config.total_dp_size
             # num_blocks must be a multiple of dp_size
-            num_blocks = math.ceil(num_blocks / dp_size) * dp_size
+            num_blocks = (num_blocks // dp_size) * dp_size
             # NOTE: we'll multiply the num_kv_heads by 2 in the function
+            if self.use_mla:
+                head_size = self.runner.model_config.hf_config.kv_lora_rank + \
+                    self.runner.model_config.hf_config.qk_rope_head_dim
+            else:
+                head_size = representative_spec.head_size
             kv_cache = create_kv_caches(
                 num_blocks=num_blocks,
                 block_size=representative_spec.block_size,
                 num_kv_heads=representative_spec.num_kv_heads,
-                head_size=representative_spec.head_size,
+                head_size=head_size,
                 mesh=self.runner.mesh,
                 layer_names=[f'kv_cache_tensor.{i}'],
                 cache_dtype=t2j_dtype(representative_spec.dtype),
+                use_mla=self.use_mla,
             )[0]
             kv_caches.append(kv_cache)
             num_blocks_list.append(num_blocks)
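
Note the rounding direction change: the removed `math.ceil` expression rounded the block count up to a multiple of the data-parallel size, while the new floor division rounds it down, so the per-shard allocation can never exceed the blocks that actually fit in the KV cache tensor. A quick check with invented numbers:

    num_blocks, dp_size = 1001, 4
    rounded_up = -(-num_blocks // dp_size) * dp_size    # 1004: may overshoot the tensor
    rounded_down = (num_blocks // dp_size) * dp_size    # 1000: always fits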
@@ -283,13 +305,8 @@ class KVCacheManager:

         def _update_layer(cache, slices):
             """The function to apply to each layer's cache and slices."""
-            reshaped_slices = slices.reshape(-1, 1, block_size,
-                                             *slices.shape[1:])
-            for (i, block_idx) in enumerate(block_numbers):
-                cache = jax.lax.dynamic_update_slice_in_dim(cache,
-                                                            reshaped_slices[i],
-                                                            block_idx,
-                                                            axis=0)
+            reshaped_slices = slices.reshape(-1, block_size, *slices.shape[1:])
+            cache.at[block_numbers].set(reshaped_slices)
             return cache

         return jax.tree.map(_update_layer, kv_caches, kv_cache_slices)
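
The rewritten `_update_layer` swaps the per-block `dynamic_update_slice_in_dim` loop for a single indexed scatter. As a standalone reminder of the JAX idiom (not the package's code): arrays are immutable, so `.at[...].set(...)` returns a new array that has to be kept.

    import jax.numpy as jnp

    cache = jnp.zeros((8, 4, 2, 2))      # (num_blocks, block_size, num_kv_heads, head_dim)
    slices = jnp.ones((2 * 4, 2, 2))     # two blocks' worth of flattened tokens
    block_numbers = jnp.array([3, 6])

    reshaped = slices.reshape(-1, 4, *slices.shape[1:])   # (2, block_size, heads, head_dim)
    cache = cache.at[block_numbers].set(reshaped)         # functional update; keep the result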
@@ -342,16 +359,12 @@ class KVCacheManager:
         """
         if block_ids == list(range(block_ids[0],
                                    block_ids[0] + len(block_ids))):
-            with runner_utils.LatencyTracker(
-                    "BatchedGatherKVSlices-for-blocks"):
-                batched_kv_cache_per_layer = self._jitted_gather_continuous_kv_cache(
-                    self.runner.kv_caches, block_ids[0], len(block_ids))
+            batched_kv_cache_per_layer = self._jitted_gather_continuous_kv_cache(
+                self.runner.kv_caches, block_ids[0], len(block_ids))

         else:
-            with runner_utils.LatencyTracker(
-                    "BatchedGatherKVSlices-for-blocks"):
-                batched_kv_cache_per_layer = self._jitted_gather_kv_cache(
-                    self.runner.kv_caches, jnp.array(block_ids))
+            batched_kv_cache_per_layer = self._jitted_gather_kv_cache(
+                self.runner.kv_caches, jnp.array(block_ids))
         return batched_kv_cache_per_layer

     def transfer_kv_cache(self,
@@ -440,6 +453,7 @@ class KVCacheManager:
                     kv_cache_slices,
                     start_block,
                 )
+                jax.block_until_ready(self.runner.kv_caches)
         else:
             with runner_utils.LatencyTracker(
                     f"JittedInsertKVCache-b{len(block_numbers)}"):
@@ -451,6 +465,7 @@ class KVCacheManager:
                     kv_cache_slices,
                     jnp.array(block_numbers),
                 )
+                jax.block_until_ready(self.runner.kv_caches)

         logger.debug(
             f"Updated kv cache entries cnt={len(self.runner.kv_caches)}")
tpu_inference/runner/persistent_batch_manager.py

@@ -14,12 +14,13 @@ class PersistentBatchManager:
     def __init__(self, requests: Dict[str, CachedRequestState],
                  input_batch: InputBatch, encoder_cache: Dict[str,
                                                               'jax.Array'],
-                 uses_mrope: bool, model_config):
+                 uses_mrope: bool, model_config, is_last_rank: bool):
         self.requests = requests
         self.input_batch = input_batch
         self.encoder_cache = encoder_cache
         self.uses_mrope = uses_mrope
         self.model_config = model_config
+        self.is_last_rank = is_last_rank

     def _reorder_batch(self, scheduler_output: "VllmSchedulerOutput") -> int:
         """ Reorder the sheduled requests to RPA kernel friendly distribution
@@ -179,9 +180,35 @@
             num_computed_tokens = req_data.num_computed_tokens[i]
             new_block_ids = req_data.new_block_ids[i]
             resumed_from_preemption = req_data.resumed_from_preemption[i]
+            num_output_tokens = req_data.num_output_tokens[i]

             # Update the cached states.
             req_state.num_computed_tokens = num_computed_tokens
+            req_index = self.input_batch.req_id_to_index.get(req_id)
+
+            if not self.is_last_rank:
+                # When using PP, the scheduler sends the sampled tokens back,
+                # because there's no direct communication between the first-
+                # stage worker and the last-stage worker.
+                new_token_ids = req_data.new_token_ids[i]
+                # Add the sampled token(s) from the previous step (if any).
+                # This doesn't include "unverified" tokens like spec tokens.
+                num_new_tokens = (num_computed_tokens + len(new_token_ids) -
+                                  req_state.num_tokens)
+                if num_new_tokens == 1:
+                    req_state.output_token_ids.append(new_token_ids[-1])
+                elif num_new_tokens > 0:
+                    req_state.output_token_ids.extend(
+                        new_token_ids[-num_new_tokens:])
+                elif num_output_tokens < len(req_state.output_token_ids):
+                    del req_state.output_token_ids[num_output_tokens:]
+                    if req_index is not None:
+                        end_idx = (self.input_batch.num_prompt_tokens[req_index] +
+                                   num_output_tokens)
+                        self.input_batch.num_tokens[req_index] = end_idx
+                        self.input_batch.num_tokens_no_spec[req_index] = end_idx
+
+            # Update the block IDs.
             if not resumed_from_preemption:
                 if new_block_ids is not None:
                     # Append the new blocks to the existing block IDs.
@@ -194,7 +221,6 @@
                 # Replace the existing block IDs with the new ones.
                 req_state.block_ids = new_block_ids

-            req_index = self.input_batch.req_id_to_index.get(req_id)
             if req_index is None:
                 # The request is not in the persistent batch.
                 # The request was either preempted and resumed later, or was not
@@ -209,6 +235,18 @@
                 self.input_batch.block_table.append_row(
                     new_block_ids, req_index)

+            # For the last rank, we don't need to update the token_ids_cpu
+            # because the sampled tokens are already cached.
+            if not self.is_last_rank:
+                start_token_index = num_computed_tokens
+                end_token_index = num_computed_tokens + len(new_token_ids)
+                self.input_batch.token_ids_cpu[
+                    req_index,
+                    start_token_index:end_token_index] = new_token_ids
+                self.input_batch.num_tokens_no_spec[
+                    req_index] = end_token_index
+                self.input_batch.num_tokens[req_index] = end_token_index
+
             # Add spec_token_ids to token_ids_cpu.
             spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
                 req_id, ())
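
The pipeline-parallel branches above let non-last ranks rebuild each request's output tokens from the ids the scheduler sends back. A small worked example of the `num_new_tokens` bookkeeping, with invented values:

    # Hypothetical cached state on a non-last pipeline rank:
    num_prompt_tokens = 10
    output_token_ids = [7]                 # one token cached from an earlier step
    num_tokens = num_prompt_tokens + len(output_token_ids)   # 11

    # Reported by the scheduler for this step:
    num_computed_tokens = 11               # prompt + the one already-processed generated token
    new_token_ids = [9]                    # sampled by the last rank in the previous step

    num_new_tokens = num_computed_tokens + len(new_token_ids) - num_tokens   # 1
    if num_new_tokens == 1:
        output_token_ids.append(new_token_ids[-1])   # this rank's copy catches up: [7, 9]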
tpu_inference/runner/structured_decoding_manager.py

@@ -61,11 +61,10 @@ class StructuredDecodingManager:
         self.runner.require_structured_out_cpu.fill(0)

         sorted_struct_requests = sorted(
-            grammar_output.structured_output_request_ids.items(),
-            key=lambda item: item[1])
+            grammar_output.structured_output_request_ids)

         cumulative_mask_idx = 0
-        for req_id, _ in sorted_struct_requests:
+        for req_id in sorted_struct_requests:
             if req_id not in self.runner.input_batch.req_id_to_index:
                 continue
             batch_index = self.runner.input_batch.req_id_to_index[req_id]