tpu-inference 0.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_adapters.py +83 -0
- tests/core/test_core_tpu.py +523 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/test_lora.py +123 -0
- tests/test_base.py +201 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +218 -0
- tests/tpu_backend_test.py +59 -0
- tpu_inference/__init__.py +30 -0
- tpu_inference/adapters/__init__.py +0 -0
- tpu_inference/adapters/vllm_adapters.py +42 -0
- tpu_inference/adapters/vllm_config_adapters.py +134 -0
- tpu_inference/backend.py +69 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/adapters.py +153 -0
- tpu_inference/core/core_tpu.py +776 -0
- tpu_inference/core/disagg_executor.py +117 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/di/__init__.py +0 -0
- tpu_inference/di/abstracts.py +28 -0
- tpu_inference/di/host.py +76 -0
- tpu_inference/di/interfaces.py +51 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/tpu_connector.py +699 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +346 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/interfaces/__init__.py +0 -0
- tpu_inference/interfaces/cache.py +31 -0
- tpu_inference/interfaces/config.py +47 -0
- tpu_inference/interfaces/config_parts.py +117 -0
- tpu_inference/interfaces/engine.py +51 -0
- tpu_inference/interfaces/outputs.py +22 -0
- tpu_inference/interfaces/params.py +21 -0
- tpu_inference/interfaces/platform.py +74 -0
- tpu_inference/interfaces/request.py +39 -0
- tpu_inference/interfaces/scheduler.py +31 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +254 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/attention_interface.py +356 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/binary_search.py +295 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +172 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +95 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
- tpu_inference/layers/jax/sharding.py +406 -0
- tpu_inference/layers/jax/transformer_block.py +76 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +184 -0
- tpu_inference/layers/vllm/fused_moe.py +399 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +34 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
- tpu_inference/layers/vllm/sharding.py +151 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +308 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1233 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +433 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/llama3.py +366 -0
- tpu_inference/models/jax/llama4.py +473 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +976 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
- tpu_inference/models/jax/utils/weight_utils.py +510 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_jax.py +257 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table_jax.py +122 -0
- tpu_inference/runner/compilation_manager.py +672 -0
- tpu_inference/runner/input_batch_jax.py +435 -0
- tpu_inference/runner/kv_cache.py +119 -0
- tpu_inference/runner/kv_cache_manager.py +460 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +208 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +250 -0
- tpu_inference/runner/structured_decoding_manager.py +89 -0
- tpu_inference/runner/tpu_jax_runner.py +771 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +334 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +294 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/_temporary_vllm_compat.py +129 -0
- tpu_inference/worker/base.py +100 -0
- tpu_inference/worker/tpu_worker_jax.py +321 -0
- tpu_inference-0.11.1.dist-info/METADATA +101 -0
- tpu_inference-0.11.1.dist-info/RECORD +168 -0
- tpu_inference-0.11.1.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
"""All-gather matmul kernel's tuned block sizes."""
|
|
3
|
+
|
|
4
|
+
import jax
|
|
5
|
+
|
|
6
|
+
# Tuned block sizes for the all-gather matmul kernel.
#
# key: a tuple built by get_key(), in this order:
# - tpu_version  (numeric TPU generation, as returned by get_tpu_version())
# - m
# - n
# - k            (m, n, k: matmul problem dimensions; presumably lhs is
#                 (m, k) and rhs is (k, n) -- TODO confirm against kernel)
# - dtype        (dtype name string, e.g. 'bfloat16')
# - tp_size      (presumably the tensor-parallel degree)
# value: a tuple
# - bn           (presumably the block size along n)
# - bk           (presumably the block size along k)
#
# Every entry below targets TPU version 6, bfloat16 and tp_size 8. Lookups
# that miss this table make get_tuned_block_sizes() return (None, None).
TUNED_BLOCK_SIZES = {
    # Entries between these markers must stay sorted (the markers are
    # presumably consumed by a keep-sorted formatting tool).
    # go/keep-sorted start
    (6, 1024, 51200, 5120, 'bfloat16', 8): (6400, 2560),
    (6, 1024, 57344, 8192, 'bfloat16', 8): (7168, 8192),
    (6, 2048, 51200, 5120, 'bfloat16', 8): (1280, 5120),
    (6, 2048, 57344, 8192, 'bfloat16', 8): (1024, 8192),
    (6, 4096, 51200, 5120, 'bfloat16', 8): (3200, 5120),
    (6, 8192, 51200, 5120, 'bfloat16', 8): (1280, 5120),
    # go/keep-sorted end
}
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def get_tpu_version() -> int:
|
|
29
|
+
"""Returns the numeric version of the TPU, or -1 if not on TPU."""
|
|
30
|
+
kind = jax.devices()[0].device_kind
|
|
31
|
+
if 'TPU' not in kind:
|
|
32
|
+
return -1
|
|
33
|
+
if kind.endswith(' lite'):
|
|
34
|
+
kind = kind[:-len(' lite')]
|
|
35
|
+
assert kind[:-1] == 'TPU v', kind
|
|
36
|
+
return int(kind[-1])
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def get_key(m, n, k, dtype, tp_size):
    """Build the ``TUNED_BLOCK_SIZES`` lookup key for one problem configuration.

    The key is prefixed with the TPU generation of the current device. A
    lookup made on one generation therefore never matches entries tuned on
    another.

    Args:
        m, n, k: Matmul problem dimensions.
        dtype: Dtype identifier; the table stores names such as 'bfloat16'.
        tp_size: Presumably the tensor-parallel degree.

    Returns:
        The tuple ``(tpu_version, m, n, k, dtype, tp_size)``.
    """
    tpu_version = get_tpu_version()
    return tpu_version, m, n, k, dtype, tp_size
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def get_tuned_block_sizes(m, n, k, dtype_name, tp_size):
    """Look up the tuned ``(bn, bk)`` block sizes for a problem configuration.

    The key is built with ``get_key``, so it includes the current TPU
    generation. When no tuned entry exists, ``(None, None)`` is returned,
    presumably so the caller can fall back to its own defaults.
    """
    lookup_key = get_key(m, n, k, dtype_name, tp_size)
    if lookup_key in TUNED_BLOCK_SIZES:
        return TUNED_BLOCK_SIZES[lookup_key]
    return (None, None)
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
"""utilities for collective kernels."""
|
|
3
|
+
|
|
4
|
+
import functools
|
|
5
|
+
|
|
6
|
+
from jax.experimental import pallas as pl
|
|
7
|
+
from jax.experimental.pallas import tpu as pltpu
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def local_barrier(left_neighbor, right_neighbor, double_barrier=True):
    """Performs a barrier with neighbors on the global barrier semaphore.

    Each device increments the barrier semaphore of its left and right
    neighbors once, then blocks until its own barrier semaphore reaches a
    count of two. It uses Pallas TPU semaphore primitives, so it must be
    called from inside a Pallas TPU kernel body.

    Optionally performs a second barrier, which prevents a potential race
    when reusing the same collective_id across kernel invocations.

    Args:
        left_neighbor: Left neighbor device id. It is used as a MESH device
            id, presumably an index along a 1-D device mesh.
        right_neighbor: Right neighbor device id (same convention as
            left_neighbor).
        double_barrier: Whether to perform a second barrier.
    """
    barrier_sem = pltpu.get_barrier_semaphore()
    for neighbor in [left_neighbor, right_neighbor]:
        pltpu.semaphore_signal(
            barrier_sem,
            inc=1,
            device_id=(neighbor, ),
            device_id_type=pltpu.DeviceIdType.MESH,
        )
    # Wait for two increments, one from each neighbor's own signal loop.
    # When left and right are the same device (two-device ring), that device
    # signals twice, which still totals 2.
    pltpu.semaphore_wait(barrier_sem, 2)
    if double_barrier:
        # The double-barrier prevents a race condition where one neighbor can
        # re-enter the kernel again on a subsequent call and increment the
        # barrier semaphore a second time. This would unblock the current device
        # even if the other neighbor is not ready yet.
        # To implement a double-barrier, we stack-allocate a second REGULAR
        # semaphore using run_scoped.
        # Decorating with partial(pl.run_scoped, ...) calls `_` immediately,
        # with the scoped semaphore as its argument; the name `_` is unused.
        # NOTE(review): signalling a neighbor's copy of this scoped semaphore
        # presumably relies on every device allocating it identically (all
        # devices run the same kernel). Confirm.
        @functools.partial(pl.run_scoped,
                           second_barrier=pltpu.SemaphoreType.REGULAR)
        def _(second_barrier):
            for neighbor in [left_neighbor, right_neighbor]:
                pltpu.semaphore_signal(
                    second_barrier,
                    inc=1,
                    device_id=(neighbor, ),
                    device_id_type=pltpu.DeviceIdType.MESH,
                )
            # Same counting as the first barrier: one increment per neighbor.
            pltpu.semaphore_wait(second_barrier, 2)
|
|
File without changes
|