tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511130813__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic.

Files changed (67)
  1. tests/kernels/fused_moe_v1_test.py +34 -303
  2. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
  3. tests/lora/test_layers.py +6 -0
  4. tests/lora/utils.py +8 -0
  5. tests/test_utils.py +16 -24
  6. tpu_inference/__init__.py +3 -22
  7. tpu_inference/core/core_tpu.py +9 -17
  8. tpu_inference/core/disagg_utils.py +8 -6
  9. tpu_inference/distributed/tpu_connector.py +4 -3
  10. tpu_inference/distributed/utils.py +2 -3
  11. tpu_inference/envs.py +8 -61
  12. tpu_inference/executors/ray_distributed_executor.py +11 -31
  13. tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
  14. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
  15. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +143 -287
  16. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +0 -7
  17. tpu_inference/layers/jax/attention/attention.py +1 -1
  18. tpu_inference/layers/{common → jax}/attention_interface.py +2 -8
  19. tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
  20. tpu_inference/layers/jax/sample/sampling.py +2 -2
  21. tpu_inference/layers/{common → jax}/sharding.py +5 -5
  22. tpu_inference/layers/vllm/attention.py +1 -1
  23. tpu_inference/layers/vllm/fused_moe.py +208 -170
  24. tpu_inference/layers/vllm/quantization/__init__.py +3 -7
  25. tpu_inference/layers/vllm/quantization/awq.py +3 -4
  26. tpu_inference/layers/vllm/quantization/common.py +1 -6
  27. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +2 -4
  28. tpu_inference/layers/vllm/quantization/unquantized.py +67 -62
  29. tpu_inference/layers/vllm/sharding.py +2 -2
  30. tpu_inference/lora/torch_punica_tpu.py +2 -1
  31. tpu_inference/mock/__init__.py +0 -0
  32. tpu_inference/mock/vllm_config_utils.py +28 -0
  33. tpu_inference/mock/vllm_envs.py +1219 -0
  34. tpu_inference/mock/vllm_logger.py +212 -0
  35. tpu_inference/mock/vllm_logging_utils.py +15 -0
  36. tpu_inference/models/common/model_loader.py +12 -46
  37. tpu_inference/models/jax/llama3.py +3 -4
  38. tpu_inference/models/jax/llama_eagle3.py +5 -8
  39. tpu_inference/models/jax/phi3.py +376 -0
  40. tpu_inference/models/jax/qwen2.py +2 -3
  41. tpu_inference/models/jax/qwen2_5_vl.py +50 -165
  42. tpu_inference/models/jax/qwen3.py +2 -3
  43. tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
  44. tpu_inference/models/jax/utils/weight_utils.py +143 -198
  45. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -32
  46. tpu_inference/platforms/tpu_platform.py +34 -47
  47. tpu_inference/runner/compilation_manager.py +60 -145
  48. tpu_inference/runner/kv_cache.py +2 -2
  49. tpu_inference/runner/kv_cache_manager.py +18 -17
  50. tpu_inference/runner/persistent_batch_manager.py +2 -40
  51. tpu_inference/runner/structured_decoding_manager.py +3 -2
  52. tpu_inference/runner/tpu_runner.py +135 -283
  53. tpu_inference/runner/utils.py +2 -2
  54. tpu_inference/spec_decode/jax/eagle3.py +21 -71
  55. tpu_inference/tpu_info.py +3 -4
  56. tpu_inference/utils.py +15 -38
  57. tpu_inference/worker/tpu_worker.py +26 -163
  58. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/METADATA +3 -4
  59. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/RECORD +63 -61
  60. tests/test_envs.py +0 -203
  61. tpu_inference/layers/common/quant_methods.py +0 -8
  62. tpu_inference/layers/vllm/quantization/mxfp4.py +0 -331
  63. tpu_inference/models/jax/llama_guard_4.py +0 -361
  64. /tpu_inference/layers/{common → jax}/binary_search.py +0 -0
  65. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/WHEEL +0 -0
  66. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/licenses/LICENSE +0 -0
  67. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/top_level.txt +0 -0
tpu_inference/runner/kv_cache_manager.py
@@ -1,16 +1,15 @@
 import functools
+import math
 from typing import TYPE_CHECKING, Dict, List

 import jax
 import jax.numpy as jnp
-import numpy as np
 import vllm.envs as envs
 from jax.sharding import NamedSharding, PartitionSpec
 from torchax.ops.mappings import t2j_dtype
+from vllm.attention import Attention
 from vllm.attention.backends.abstract import AttentionType
-from vllm.attention.layer import Attention
 from vllm.config import get_layers_from_vllm_config
-from vllm.utils.math_utils import cdiv
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec, MLAAttentionSpec,
                                         SlidingWindowSpec)
@@ -176,11 +175,6 @@ class KVCacheManager:
         )
         self.runner.input_batch = new_input_batch
         self.runner.persistent_batch_manager.input_batch = new_input_batch
-        self.runner.block_tables_cpu = [
-            np.zeros((self.runner.max_num_reqs,
-                      cdiv(self.runner.max_model_len, block_size)),
-                     dtype=np.int32) for block_size in block_sizes
-        ]

     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         self.maybe_reinitialize_input_batch(kv_cache_config)
@@ -196,7 +190,7 @@ class KVCacheManager:
             num_blocks = kv_cache_tensor.size // page_size_bytes
             dp_size = self.runner.vllm_config.sharding_config.total_dp_size
             # num_blocks must be a multiple of dp_size
-            num_blocks = (num_blocks // dp_size) * dp_size
+            num_blocks = math.ceil(num_blocks / dp_size) * dp_size
             # NOTE: we'll multiply the num_kv_heads by 2 in the function
             kv_cache = create_kv_caches(
                 num_blocks=num_blocks,
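The change above rounds num_blocks up to the next multiple of dp_size instead of down to the previous one. A minimal sketch of the two behaviors, using illustrative values for num_blocks and dp_size:

    import math

    num_blocks, dp_size = 1023, 4

    # Old behavior: largest multiple of dp_size that still fits.
    floor_multiple = (num_blocks // dp_size) * dp_size          # 1020
    # New behavior: smallest multiple of dp_size that covers num_blocks.
    ceil_multiple = math.ceil(num_blocks / dp_size) * dp_size   # 1024

    assert (floor_multiple, ceil_multiple) == (1020, 1024)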
@@ -289,8 +283,13 @@ class KVCacheManager:

         def _update_layer(cache, slices):
             """The function to apply to each layer's cache and slices."""
-            reshaped_slices = slices.reshape(-1, block_size, *slices.shape[1:])
-            cache.at[block_numbers].set(reshaped_slices)
+            reshaped_slices = slices.reshape(-1, 1, block_size,
+                                             *slices.shape[1:])
+            for (i, block_idx) in enumerate(block_numbers):
+                cache = jax.lax.dynamic_update_slice_in_dim(cache,
+                                                            reshaped_slices[i],
+                                                            block_idx,
+                                                            axis=0)
             return cache

         return jax.tree.map(_update_layer, kv_caches, kv_cache_slices)
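The rewritten _update_layer no longer scatters all blocks at once with cache.at[block_numbers].set(...); it writes one block at a time with jax.lax.dynamic_update_slice_in_dim. A self-contained sketch of that primitive, with made-up shapes standing in for the real KV-cache layout:

    import jax
    import jax.numpy as jnp

    # Illustrative layout only: (num_blocks, block_size, num_kv_heads, head_dim).
    cache = jnp.zeros((8, 4, 2, 16))
    new_block = jnp.ones((1, 4, 2, 16))  # one block, same rank as the cache
    block_idx = 5

    # Returns a copy of `cache` with `new_block` written at blocks [5:6] along
    # axis 0; the original array is untouched because JAX arrays are immutable.
    updated = jax.lax.dynamic_update_slice_in_dim(cache, new_block, block_idx, axis=0)
    assert bool((updated[block_idx] == 1).all()) and bool((cache[block_idx] == 0).all())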
@@ -343,12 +342,16 @@ class KVCacheManager:
         """
         if block_ids == list(range(block_ids[0],
                                    block_ids[0] + len(block_ids))):
-            batched_kv_cache_per_layer = self._jitted_gather_continuous_kv_cache(
-                self.runner.kv_caches, block_ids[0], len(block_ids))
+            with runner_utils.LatencyTracker(
+                    "BatchedGatherKVSlices-for-blocks"):
+                batched_kv_cache_per_layer = self._jitted_gather_continuous_kv_cache(
+                    self.runner.kv_caches, block_ids[0], len(block_ids))

         else:
-            batched_kv_cache_per_layer = self._jitted_gather_kv_cache(
-                self.runner.kv_caches, jnp.array(block_ids))
+            with runner_utils.LatencyTracker(
+                    "BatchedGatherKVSlices-for-blocks"):
+                batched_kv_cache_per_layer = self._jitted_gather_kv_cache(
+                    self.runner.kv_caches, jnp.array(block_ids))
         return batched_kv_cache_per_layer

     def transfer_kv_cache(self,
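Both gather paths are now wrapped in runner_utils.LatencyTracker, whose implementation is not part of this diff. As a rough, hypothetical sketch of what such a context manager typically does (names and logging behavior are assumptions, not the actual tpu_inference code):

    import logging
    import time
    from contextlib import contextmanager

    logger = logging.getLogger(__name__)

    @contextmanager
    def latency_tracker(label: str):
        # Hypothetical stand-in: measure wall-clock time spent in the block
        # and log it under the given label.
        start = time.perf_counter()
        try:
            yield
        finally:
            elapsed_ms = (time.perf_counter() - start) * 1e3
            logger.debug("%s took %.3f ms", label, elapsed_ms)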
@@ -437,7 +440,6 @@ class KVCacheManager:
                     kv_cache_slices,
                     start_block,
                 )
-                jax.block_until_ready(self.runner.kv_caches)
         else:
             with runner_utils.LatencyTracker(
                     f"JittedInsertKVCache-b{len(block_numbers)}"):
@@ -449,7 +451,6 @@ class KVCacheManager:
                     kv_cache_slices,
                     jnp.array(block_numbers),
                 )
-                jax.block_until_ready(self.runner.kv_caches)

         logger.debug(
             f"Updated kv cache entries cnt={len(self.runner.kv_caches)}")
tpu_inference/runner/persistent_batch_manager.py
@@ -14,13 +14,12 @@ class PersistentBatchManager:
     def __init__(self, requests: Dict[str, CachedRequestState],
                  input_batch: InputBatch, encoder_cache: Dict[str,
                                                               'jax.Array'],
-                 uses_mrope: bool, model_config, is_last_rank: bool):
+                 uses_mrope: bool, model_config):
         self.requests = requests
         self.input_batch = input_batch
         self.encoder_cache = encoder_cache
         self.uses_mrope = uses_mrope
         self.model_config = model_config
-        self.is_last_rank = is_last_rank

     def _reorder_batch(self, scheduler_output: "VllmSchedulerOutput") -> int:
         """ Reorder the sheduled requests to RPA kernel friendly distribution
@@ -180,35 +179,9 @@ class PersistentBatchManager:
            num_computed_tokens = req_data.num_computed_tokens[i]
            new_block_ids = req_data.new_block_ids[i]
            resumed_from_preemption = req_data.resumed_from_preemption[i]
-            num_output_tokens = req_data.num_output_tokens[i]

            # Update the cached states.
            req_state.num_computed_tokens = num_computed_tokens
-            req_index = self.input_batch.req_id_to_index.get(req_id)
-
-            if not self.is_last_rank:
-                # When using PP, the scheduler sends the sampled tokens back,
-                # because there's no direct communication between the first-
-                # stage worker and the last-stage worker.
-                new_token_ids = req_data.new_token_ids[i]
-                # Add the sampled token(s) from the previous step (if any).
-                # This doesn't include "unverified" tokens like spec tokens.
-                num_new_tokens = (num_computed_tokens + len(new_token_ids) -
-                                  req_state.num_tokens)
-                if num_new_tokens == 1:
-                    req_state.output_token_ids.append(new_token_ids[-1])
-                elif num_new_tokens > 0:
-                    req_state.output_token_ids.extend(
-                        new_token_ids[-num_new_tokens:])
-            elif num_output_tokens < len(req_state.output_token_ids):
-                del req_state.output_token_ids[num_output_tokens:]
-                if req_index is not None:
-                    end_idx = (self.input_batch.num_prompt_tokens[req_index] +
-                               num_output_tokens)
-                    self.input_batch.num_tokens[req_index] = end_idx
-                    self.input_batch.num_tokens_no_spec[req_index] = end_idx
-
-            # Update the block IDs.
            if not resumed_from_preemption:
                if new_block_ids is not None:
                    # Append the new blocks to the existing block IDs.
@@ -221,6 +194,7 @@ class PersistentBatchManager:
                    # Replace the existing block IDs with the new ones.
                    req_state.block_ids = new_block_ids

+            req_index = self.input_batch.req_id_to_index.get(req_id)
            if req_index is None:
                # The request is not in the persistent batch.
                # The request was either preempted and resumed later, or was not
@@ -235,18 +209,6 @@ class PersistentBatchManager:
                self.input_batch.block_table.append_row(
                    new_block_ids, req_index)

-            # For the last rank, we don't need to update the token_ids_cpu
-            # because the sampled tokens are already cached.
-            if not self.is_last_rank:
-                start_token_index = num_computed_tokens
-                end_token_index = num_computed_tokens + len(new_token_ids)
-                self.input_batch.token_ids_cpu[
-                    req_index,
-                    start_token_index:end_token_index] = new_token_ids
-                self.input_batch.num_tokens_no_spec[
-                    req_index] = end_token_index
-                self.input_batch.num_tokens[req_index] = end_token_index
-
            # Add spec_token_ids to token_ids_cpu.
            spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
                req_id, ())
tpu_inference/runner/structured_decoding_manager.py
@@ -61,10 +61,11 @@ class StructuredDecodingManager:
        self.runner.require_structured_out_cpu.fill(0)

        sorted_struct_requests = sorted(
-            grammar_output.structured_output_request_ids)
+            grammar_output.structured_output_request_ids.items(),
+            key=lambda item: item[1])

        cumulative_mask_idx = 0
-        for req_id in sorted_struct_requests:
+        for req_id, _ in sorted_struct_requests:
            if req_id not in self.runner.input_batch.req_id_to_index:
                continue
            batch_index = self.runner.input_batch.req_id_to_index[req_id]
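The structured-decoding change sorts the request-id mapping by its values rather than by its keys, presumably so requests are visited in the order of the indices the scheduler assigned. A tiny illustration with made-up data (the mapping contents are hypothetical):

    # Hypothetical mapping of request id -> index assigned by the scheduler.
    structured_output_request_ids = {"req-b": 2, "req-a": 0, "req-c": 1}

    # Old: sorted(mapping) orders the keys -> ["req-a", "req-b", "req-c"]
    # New: sort (key, value) pairs by value -> req-a, req-c, req-b
    by_value = sorted(structured_output_request_ids.items(), key=lambda item: item[1])
    assert [req_id for req_id, _ in by_value] == ["req-a", "req-c", "req-b"]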