tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/kernels/mla_v1_test.py +129 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
- tests/lora/test_layers.py +4 -7
- tests/lora/test_lora_perf.py +53 -0
- tests/lora/utils.py +0 -8
- tests/test_envs.py +110 -12
- tests/test_quantization.py +3 -0
- tests/test_utils.py +1 -2
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +3 -4
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +93 -9
- tpu_inference/executors/ray_distributed_executor.py +9 -2
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
- tpu_inference/kernels/mla/v1/kernel.py +98 -120
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +140 -67
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +204 -120
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +11 -7
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +170 -208
- tpu_inference/layers/vllm/linear_common.py +43 -21
- tpu_inference/layers/vllm/quantization/common.py +11 -6
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
- tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
- tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +84 -28
- tpu_inference/models/jax/deepseek_v3.py +185 -64
- tpu_inference/models/jax/gpt_oss.py +3 -3
- tpu_inference/models/jax/llama3.py +2 -1
- tpu_inference/models/jax/llama_eagle3.py +8 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +2 -1
- tpu_inference/models/jax/qwen2_5_vl.py +163 -48
- tpu_inference/models/jax/qwen3.py +2 -1
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
- tpu_inference/models/jax/utils/weight_utils.py +205 -144
- tpu_inference/models/vllm/vllm_model_wrapper.py +14 -8
- tpu_inference/platforms/tpu_platform.py +34 -50
- tpu_inference/runner/compilation_manager.py +144 -60
- tpu_inference/runner/kv_cache.py +40 -20
- tpu_inference/runner/kv_cache_manager.py +48 -33
- tpu_inference/runner/persistent_batch_manager.py +40 -2
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +280 -149
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +71 -21
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +46 -18
- tpu_inference/worker/tpu_worker.py +197 -63
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +9 -10
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +70 -74
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
|
@@ -540,12 +540,16 @@ def get_vmem_estimate_bytes(
|
|
|
540
540
|
"""Returns the total vmem bytes used by the kernel."""
|
|
541
541
|
m_per_device = m // tp_size
|
|
542
542
|
n_per_device = n // tp_size
|
|
543
|
-
y_vmem_bytes = n_per_device * k * dtypes.bit_width(y_dtype)
|
|
543
|
+
y_vmem_bytes = (n_per_device * k * (dtypes.bit_width(y_dtype) if hasattr(
|
|
544
|
+
dtypes, "bit_width") else dtypes.itemsize_bits(y_dtype)) // 8)
|
|
544
545
|
total_bytes = (
|
|
545
|
-
2 * m_per_device * k *
|
|
546
|
-
|
|
546
|
+
2 * m_per_device * k *
|
|
547
|
+
(dtypes.bit_width(x_dtype) if hasattr(dtypes, "bit_width") else
|
|
548
|
+
dtypes.itemsize_bits(x_dtype)) // 8 # x_vmem_scratch_ref
|
|
547
549
|
+ y_vmem_bytes # y_vmem_scratch_ref
|
|
548
|
-
+ 2 * m * bn *
|
|
550
|
+
+ 2 * m * bn *
|
|
551
|
+
(dtypes.bit_width(out_dtype) if hasattr(dtypes, "bit_width") else
|
|
552
|
+
dtypes.itemsize_bits(out_dtype)) // 8 # o_vmem_scratch_ref
|
|
549
553
|
+ acc_bytes # acc_vmem_scratch_ref, jnp.float32
|
|
550
554
|
)
|
|
551
555
|
return total_bytes
|
|
@@ -639,8 +643,10 @@ def all_gather_matmul(
|
|
|
639
643
|
# NOTE(chengjiyao): acc buffer is not used in the grid_k == 1 case.
|
|
640
644
|
if grid_k == 1:
|
|
641
645
|
acc_shape = (8, 128)
|
|
642
|
-
acc_bytes =
|
|
643
|
-
|
|
646
|
+
acc_bytes = (
|
|
647
|
+
acc_shape[0] *
|
|
648
|
+
acc_shape[1] * (dtypes.bit_width(jnp.float32) if hasattr(
|
|
649
|
+
dtypes, "bit_width") else dtypes.itemsize_bits(jnp.float32)) // 8)
|
|
644
650
|
y_vmem_shape = (n_per_device, k) if rhs_transpose else (k, n_per_device)
|
|
645
651
|
estimated_vmem_bytes = get_vmem_estimate_bytes(
|
|
646
652
|
m,
|
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
# SPDX-License-Identifier: Apache-2.0
|
|
2
2
|
"""All-gather matmul kernel's tuned block sizes."""
|
|
3
3
|
|
|
4
|
+
import re
|
|
5
|
+
|
|
4
6
|
import jax
|
|
5
7
|
|
|
6
8
|
# key:
|
|
@@ -32,8 +34,11 @@ def get_tpu_version() -> int:
|
|
|
32
34
|
return -1
|
|
33
35
|
if kind.endswith(' lite'):
|
|
34
36
|
kind = kind[:-len(' lite')]
|
|
35
|
-
|
|
36
|
-
|
|
37
|
+
|
|
38
|
+
# v6: "TPU v6"
|
|
39
|
+
# v7: "TPU7x"
|
|
40
|
+
assert kind[:3] == 'TPU', kind
|
|
41
|
+
return int(re.search(r'\d+', kind).group())
|
|
37
42
|
|
|
38
43
|
|
|
39
44
|
def get_key(
|