tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference has been flagged as potentially problematic; see the package's registry page for more details.

Files changed (76):
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/kernels/mla_v1_test.py +129 -41
  3. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  4. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
  5. tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
  6. tests/lora/test_layers.py +4 -7
  7. tests/lora/test_lora_perf.py +53 -0
  8. tests/lora/utils.py +0 -8
  9. tests/test_envs.py +110 -12
  10. tests/test_quantization.py +3 -0
  11. tests/test_utils.py +1 -2
  12. tpu_inference/__init__.py +22 -3
  13. tpu_inference/core/disagg_utils.py +6 -8
  14. tpu_inference/distributed/tpu_connector.py +3 -4
  15. tpu_inference/distributed/utils.py +3 -2
  16. tpu_inference/envs.py +93 -9
  17. tpu_inference/executors/ray_distributed_executor.py +9 -2
  18. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  19. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  20. tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
  21. tpu_inference/kernels/mla/v1/kernel.py +98 -120
  22. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  23. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  24. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  25. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +140 -67
  26. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +204 -120
  27. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
  28. tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
  29. tpu_inference/layers/common/attention_interface.py +7 -1
  30. tpu_inference/layers/common/sharding.py +11 -7
  31. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
  32. tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
  33. tpu_inference/layers/vllm/fused_moe.py +170 -208
  34. tpu_inference/layers/vllm/linear_common.py +43 -21
  35. tpu_inference/layers/vllm/quantization/common.py +11 -6
  36. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
  37. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
  38. tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
  39. tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
  40. tpu_inference/layers/vllm/sharding.py +2 -2
  41. tpu_inference/lora/torch_punica_tpu.py +1 -2
  42. tpu_inference/models/common/model_loader.py +84 -28
  43. tpu_inference/models/jax/deepseek_v3.py +185 -64
  44. tpu_inference/models/jax/gpt_oss.py +3 -3
  45. tpu_inference/models/jax/llama3.py +2 -1
  46. tpu_inference/models/jax/llama_eagle3.py +8 -5
  47. tpu_inference/models/jax/llama_guard_4.py +361 -0
  48. tpu_inference/models/jax/qwen2.py +2 -1
  49. tpu_inference/models/jax/qwen2_5_vl.py +163 -48
  50. tpu_inference/models/jax/qwen3.py +2 -1
  51. tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
  52. tpu_inference/models/jax/utils/weight_utils.py +205 -144
  53. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -8
  54. tpu_inference/platforms/tpu_platform.py +34 -50
  55. tpu_inference/runner/compilation_manager.py +144 -60
  56. tpu_inference/runner/kv_cache.py +40 -20
  57. tpu_inference/runner/kv_cache_manager.py +48 -33
  58. tpu_inference/runner/persistent_batch_manager.py +40 -2
  59. tpu_inference/runner/structured_decoding_manager.py +2 -3
  60. tpu_inference/runner/tpu_runner.py +280 -149
  61. tpu_inference/runner/utils.py +2 -2
  62. tpu_inference/spec_decode/jax/eagle3.py +71 -21
  63. tpu_inference/tpu_info.py +4 -3
  64. tpu_inference/utils.py +46 -18
  65. tpu_inference/worker/tpu_worker.py +197 -63
  66. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +9 -10
  67. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +70 -74
  68. tpu_inference/mock/__init__.py +0 -0
  69. tpu_inference/mock/vllm_config_utils.py +0 -28
  70. tpu_inference/mock/vllm_envs.py +0 -1219
  71. tpu_inference/mock/vllm_logger.py +0 -212
  72. tpu_inference/mock/vllm_logging_utils.py +0 -15
  73. tpu_inference/models/jax/phi3.py +0 -376
  74. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
  75. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
  76. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
@@ -540,12 +540,16 @@ def get_vmem_estimate_bytes(
540
540
  """Returns the total vmem bytes used by the kernel."""
541
541
  m_per_device = m // tp_size
542
542
  n_per_device = n // tp_size
543
- y_vmem_bytes = n_per_device * k * dtypes.bit_width(y_dtype) // 8
543
+ y_vmem_bytes = (n_per_device * k * (dtypes.bit_width(y_dtype) if hasattr(
544
+ dtypes, "bit_width") else dtypes.itemsize_bits(y_dtype)) // 8)
544
545
  total_bytes = (
545
- 2 * m_per_device * k * dtypes.bit_width(x_dtype) //
546
- 8 # x_vmem_scratch_ref
546
+ 2 * m_per_device * k *
547
+ (dtypes.bit_width(x_dtype) if hasattr(dtypes, "bit_width") else
548
+ dtypes.itemsize_bits(x_dtype)) // 8 # x_vmem_scratch_ref
547
549
  + y_vmem_bytes # y_vmem_scratch_ref
548
- + 2 * m * bn * dtypes.bit_width(out_dtype) // 8 # o_vmem_scratch_ref
550
+ + 2 * m * bn *
551
+ (dtypes.bit_width(out_dtype) if hasattr(dtypes, "bit_width") else
552
+ dtypes.itemsize_bits(out_dtype)) // 8 # o_vmem_scratch_ref
549
553
  + acc_bytes # acc_vmem_scratch_ref, jnp.float32
550
554
  )
551
555
  return total_bytes
@@ -639,8 +643,10 @@ def all_gather_matmul(
639
643
  # NOTE(chengjiyao): acc buffer is not used in the grid_k == 1 case.
640
644
  if grid_k == 1:
641
645
  acc_shape = (8, 128)
642
- acc_bytes = acc_shape[0] * acc_shape[1] * dtypes.bit_width(
643
- jnp.float32) // 8
646
+ acc_bytes = (
647
+ acc_shape[0] *
648
+ acc_shape[1] * (dtypes.bit_width(jnp.float32) if hasattr(
649
+ dtypes, "bit_width") else dtypes.itemsize_bits(jnp.float32)) // 8)
644
650
  y_vmem_shape = (n_per_device, k) if rhs_transpose else (k, n_per_device)
645
651
  estimated_vmem_bytes = get_vmem_estimate_bytes(
646
652
  m,
@@ -1,6 +1,8 @@
1
1
  # SPDX-License-Identifier: Apache-2.0
2
2
  """All-gather matmul kernel's tuned block sizes."""
3
3
 
4
+ import re
5
+
4
6
  import jax
5
7
 
6
8
  # key:
@@ -32,8 +34,11 @@ def get_tpu_version() -> int:
32
34
  return -1
33
35
  if kind.endswith(' lite'):
34
36
  kind = kind[:-len(' lite')]
35
- assert kind[:-1] == 'TPU v', kind
36
- return int(kind[-1])
37
+
38
+ # v6: "TPU v6"
39
+ # v7: "TPU7x"
40
+ assert kind[:3] == 'TPU', kind
41
+ return int(re.search(r'\d+', kind).group())
37
42
 
38
43
 
39
44
  def get_key(