tpu-inference 0.11.1.dev202511150811__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_dp_scheduler.py +899 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/fused_moe_v1_test.py +105 -0
- tests/kernels/mla_v1_test.py +396 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/conftest.py +32 -0
- tests/lora/test_bgmv.py +43 -0
- tests/lora/test_layers.py +654 -0
- tests/lora/test_lora.py +133 -0
- tests/lora/utils.py +96 -0
- tests/test_base.py +201 -0
- tests/test_envs.py +182 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +236 -0
- tpu_inference/__init__.py +34 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/core/sched/__init__.py +0 -0
- tpu_inference/core/sched/dp_scheduler.py +523 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/jax_parallel_state.py +67 -0
- tpu_inference/distributed/tpu_connector.py +728 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +107 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +362 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
- tpu_inference/kernels/mla/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/kernel.py +1349 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_interface.py +390 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/common/sharding.py +582 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +255 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +280 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +96 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
- tpu_inference/layers/jax/transformer_block.py +107 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +507 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +39 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
- tpu_inference/layers/vllm/sharding.py +230 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +311 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +444 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/gpt_oss.py +492 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
- tpu_inference/models/jax/llama3.py +375 -0
- tpu_inference/models/jax/llama4.py +629 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
- tpu_inference/models/jax/utils/weight_utils.py +529 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_platform.py +269 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +780 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +132 -0
- tpu_inference/runner/kv_cache_manager.py +479 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +217 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +248 -0
- tpu_inference/runner/structured_decoding_manager.py +88 -0
- tpu_inference/runner/tpu_runner.py +1620 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +367 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +317 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/tpu_worker.py +321 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
`tpu_inference/models/jax/utils/quantization/mxfp4_utils.py` (new file, `@@ -0,0 +1,105 @@`):

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

# MXFP4 constants
MXFP4_BLOCK_SIZE: int = 32
# Exponent-only e8m0 scale bias used by MXFP4 scales
MXFP4_SCALE_BIAS: int = 127
# Name used in config.json quantization_config["quant_method"]
MXFP4_QUANT_METHOD: str = "mxfp4"

# Precompute a small LUT once; move to device on demand (cheap 16-element copy)
FP4_LUT = torch.tensor(
    [
        0.0,
        0.5,
        1.0,
        1.5,
        2.0,
        3.0,
        4.0,
        6.0,  # 0b0000-0b0111
        -0.0,
        -0.5,
        -1.0,
        -1.5,
        -2.0,
        -3.0,
        -4.0,
        -6.0,  # 0b1000-0b1111
    ],
    dtype=torch.float32)


def unpack_mxfp4(packed: torch.Tensor) -> torch.Tensor:
    """Unpack uint8 (..., 16) -> fp4 values (..., 32) using low->high nibble order.

    Returns float32 values corresponding to FP4 codebook entries.
    """
    assert packed.dtype == torch.uint8
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    idx = torch.stack([low, high], dim=-1).flatten(-2)
    lut = FP4_LUT.to(packed.device)
    return lut[idx.long()]


def e8m0_to_fp32(u8: torch.Tensor) -> torch.Tensor:
    """Convert e8m0 uint8 exponents to power-of-two scales using MXFP4_SCALE_BIAS.

    Uses ldexp for exact power-of-two scaling: 1.0 * 2**(u8 - bias).
    """
    exponents = (u8.to(torch.int32) - int(MXFP4_SCALE_BIAS)).to(torch.int32)
    ones = torch.ones_like(u8, dtype=torch.float32)
    return torch.ldexp(ones, exponents)


def dequant_mxfp4_to_bf16(blocks_u8: torch.Tensor,
                          scales_u8: torch.Tensor) -> torch.Tensor:
    """Dequantize MXFP4 blocks/scales into bfloat16 values.

    Args:
        blocks_u8: uint8 tensor shaped [..., Kb, 16], each byte holds 2 FP4 codes.
        scales_u8: uint8 tensor shaped [..., Kb], exponent-only e8m0 per 32-value block.

    Returns:
        torch.bfloat16 tensor with last logical dimension K = Kb * 32.
    """
    if blocks_u8.dtype != torch.uint8 or scales_u8.dtype != torch.uint8:
        raise ValueError(
            f"Expected uint8 inputs, got blocks={blocks_u8.dtype}, scales={scales_u8.dtype}"
        )
    # Unpack FP4 codes to float32 values [..., Kb, 32]
    fp4_vals = unpack_mxfp4(blocks_u8)  # (..., Kb, 32)
    # Compute power-of-two scales and apply per block
    scales = e8m0_to_fp32(scales_u8).unsqueeze(-1)  # (..., Kb, 1)
    full = (fp4_vals * scales).reshape(*fp4_vals.shape[:-2],
                                       fp4_vals.shape[-2] * MXFP4_BLOCK_SIZE)
    return full.to(torch.bfloat16)


def unpack_mxfp4_to_fp32(
        blocks_u8: torch.Tensor,
        scales_u8: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Decode MXFP4 packed blocks and e8m0 scales to float32 codes and scales.

    Args:
        blocks_u8: uint8 tensor shaped [..., Kb, 16], each byte packs two FP4 codes.
        scales_u8: uint8 tensor shaped [..., Kb], exponent-only e8m0 per block.

    Returns:
        (codes_fp32, scales_fp32), where
        - codes_fp32 has shape [..., Kb*32] and dtype float32
        - scales_fp32 has shape [..., Kb] and dtype float32
    """
    if blocks_u8.dtype != torch.uint8 or scales_u8.dtype != torch.uint8:
        raise ValueError(
            f"Expected uint8 inputs, got blocks={blocks_u8.dtype}, scales={scales_u8.dtype}"
        )
    fp4_vals = unpack_mxfp4(blocks_u8)  # (..., Kb, 32) float32
    codes_fp32 = fp4_vals.reshape(*fp4_vals.shape[:-2],
                                  fp4_vals.shape[-2] * MXFP4_BLOCK_SIZE)
    scales_fp32 = e8m0_to_fp32(scales_u8)  # (..., Kb) float32
    return codes_fp32, scales_fp32
```