tpu-inference 0.12.0.dev20251222__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +67 -0
- tests/core/test_dp_scheduler.py +724 -0
- tests/core/test_init.py +63 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +393 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +291 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +388 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +498 -0
- tests/kernels/quantized_matmul_kernel_test.py +159 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/layers/jax/test_qwix.py +969 -0
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +297 -0
- tests/layers/vllm/test_unquantized.py +621 -0
- tests/layers/vllm/utils.py +72 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +46 -0
- tests/lora/test_bgmv.py +57 -0
- tests/lora/test_layers.py +666 -0
- tests/lora/test_lora.py +147 -0
- tests/lora/test_lora_perf.py +67 -0
- tests/lora/utils.py +88 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +606 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +202 -0
- tests/runner/test_tpu_runner_dp.py +1033 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +215 -0
- tests/test_envs.py +280 -0
- tests/test_tpu_info.py +134 -0
- tests/test_utils.py +193 -0
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +67 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +49 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +814 -0
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +81 -0
- tpu_inference/distributed/tpu_connector.py +732 -0
- tpu_inference/distributed/utils.py +112 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +191 -0
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +399 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +272 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +1340 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +403 -0
- tpu_inference/layers/common/attention_metadata.py +48 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +23 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +600 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +268 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
- tpu_inference/layers/jax/base.py +165 -0
- tpu_inference/layers/jax/constants.py +101 -0
- tpu_inference/layers/jax/layers.py +315 -0
- tpu_inference/layers/jax/misc.py +30 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
- tpu_inference/layers/jax/moe/moe.py +249 -0
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +294 -0
- tpu_inference/layers/jax/rope_interface.py +228 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
- tpu_inference/layers/jax/sample/sampling.py +110 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
- tpu_inference/layers/jax/transformer_block.py +121 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +502 -0
- tpu_inference/layers/vllm/linear_common.py +221 -0
- tpu_inference/layers/vllm/quantization/__init__.py +55 -0
- tpu_inference/layers/vllm/quantization/awq.py +221 -0
- tpu_inference/layers/vllm/quantization/common.py +124 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
- tpu_inference/layers/vllm/sharding.py +244 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +98 -0
- tpu_inference/lora/torch_punica_tpu.py +310 -0
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +520 -0
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +978 -0
- tpu_inference/models/jax/gpt_oss.py +508 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
- tpu_inference/models/jax/llama3.py +436 -0
- tpu_inference/models/jax/llama4.py +643 -0
- tpu_inference/models/jax/llama_eagle3.py +350 -0
- tpu_inference/models/jax/llama_guard_4.py +375 -0
- tpu_inference/models/jax/qwen2.py +390 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
- tpu_inference/models/jax/qwen3.py +318 -0
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +110 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
- tpu_inference/models/jax/utils/weight_utils.py +621 -0
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
- tpu_inference/platforms/__init__.py +16 -0
- tpu_inference/platforms/tpu_platform.py +258 -0
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +890 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +166 -0
- tpu_inference/runner/kv_cache_manager.py +508 -0
- tpu_inference/runner/lora_utils.py +106 -0
- tpu_inference/runner/multimodal_manager.py +231 -0
- tpu_inference/runner/persistent_batch_manager.py +296 -0
- tpu_inference/runner/speculative_decoding_manager.py +262 -0
- tpu_inference/runner/structured_decoding_manager.py +101 -0
- tpu_inference/runner/tpu_runner.py +1768 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +430 -0
- tpu_inference/tpu_info.py +92 -0
- tpu_inference/utils.py +345 -0
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +468 -0
- tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
- tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
- tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
- tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tpu_inference/lora/torch_lora_ops.py
@@ -0,0 +1,98 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import jax
+import jax.numpy as jnp
+import torch
+from torchax.interop import call_jax
+
+
+@jax.jit
+def bgmv_jax(
+        inputs,  # [num_tokens, hidden_size]
+        loras,  # [num_loras, lora_rank, hidden_size]
+        idxs,  # [num_tokens]
+):
+    return jnp.einsum(
+        "td,tX,Xld->tl",
+        inputs,
+        jax.nn.one_hot(idxs, loras.shape[0], dtype=inputs.dtype),
+        loras,
+    )
+
+
+def bgmv_torch(
+        inputs,  # [num_tokens, hidden_size]
+        loras,  # [num_loras, 1, lora_rank, hidden_size]
+        idxs,  # [num_tokens]
+):  # [num_tokens, lora_rank]
+    # TODO(xiowei): use the below one_hot impl (added in https://github.com/pytorch/xla/pull/9523) after we upgrade torchax version.
+    # if len(loras.shape) == 4:
+    #     loras = loras.squeeze(axis=1)
+    # return torch.einsum(
+    #     "td,tX,Xld->tl",
+    #     inputs,
+    #     torch.nn.functional.one_hot(idxs.long(), loras.shape[0]),
+    #     loras,
+    # )  # [num_tokens, lora_rank]
+
+    if len(loras.shape) == 4:
+        loras = loras.squeeze(axis=1)
+    return call_jax(bgmv_jax, inputs, loras, idxs)
+
+
+def bgmv_shrink(
+    inputs: torch.Tensor,
+    lora_b_weights: torch.Tensor,
+    lora_indices_tensor: torch.Tensor,
+    scaling: float = 1.0,
+):
+    """
+    Args:
+        inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
+        lora_b_weights (torch.Tensor): LoRA weights of shape
+            [max_loras, 1, max_lora_rank, hidden_size].
+        output_tensor (torch.Tensor): (Unused) output tensor (placeholder).
+        lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
+            indicating which LoRA matrix to use for each token.
+        scaling (float, optional): Scalar multiplier applied to the output.
+    """
+    return scaling * bgmv_torch(inputs, lora_b_weights, lora_indices_tensor)
+
+
+def bgmv_expand_slice(
+    inputs: torch.Tensor,
+    lora_b_weights: torch.Tensor,
+    output_tensor: torch.Tensor,
+    lora_indices_tensor: torch.Tensor,
+    slice_offset: int,
+    slice_size: int,
+    add_inputs: bool = True,
+):
+    """
+    Args:
+        inputs (torch.Tensor): Input tensor of shape [num_tokens, lora_rank].
+
+        lora_b_weights (torch.Tensor): LoRA weights of shape
+            [num_loras, 1, out_features, lora_rank].
+
+        output_tensor (torch.Tensor): output tensor of shape
+            [num_tokens, out_features * num_slices].
+
+        lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
+            indicating which LoRA matrix to use for each token.
+        add_inputs (bool): Whether or not to add the input tensor to the output
+            tensor.
+    """
+    outputs = bgmv_torch(inputs, lora_b_weights,
+                         lora_indices_tensor)  # [num_tokens, out_features]
+
+    # Create a padded tensor manually to avoid issues with F.pad on sharded tensors.
+    # This is a more robust way to handle padding in a distributed environment.
+    outputs_padded = torch.zeros_like(output_tensor)
+    outputs_padded[:, slice_offset:slice_offset + slice_size] = outputs
+
+    if add_inputs:
+        return output_tensor + outputs_padded
+    else:
+        return outputs_padded
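
A minimal usage sketch (not part of the packaged file above): how bgmv_shrink and bgmv_expand_slice compose into a single-slice LoRA update. Tensor sizes are illustrative only, and the snippet assumes an environment where torchax's call_jax can bridge torch tensors to JAX (e.g. inside torchax.default_env()); in the package these ops are driven by PunicaWrapperTPU below.

    import torch
    import torchax

    from tpu_inference.lora.torch_lora_ops import bgmv_expand_slice, bgmv_shrink

    num_tokens, hidden_size, rank, out_features, max_loras = 4, 16, 8, 32, 2

    with torchax.default_env():
        x = torch.randn(num_tokens, hidden_size)
        lora_a = torch.randn(max_loras, 1, rank, hidden_size)   # shrink ("A") weights
        lora_b = torch.randn(max_loras, 1, out_features, rank)  # expand ("B") weights
        idxs = torch.zeros(num_tokens, dtype=torch.int32)       # every token uses LoRA 0
        y = torch.zeros(num_tokens, out_features)               # base-layer output

        hidden = bgmv_shrink(x, lora_a, idxs, scaling=1.0)      # [num_tokens, rank]
        y = bgmv_expand_slice(hidden, lora_b, y, idxs,
                              slice_offset=0, slice_size=out_features)
        # y now holds the base output plus each token's LoRA delta (x @ A_i^T @ B_i^T).
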
tpu_inference/lora/torch_punica_tpu.py
@@ -0,0 +1,310 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import math
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+import torch.nn.functional as F
+import torchax
+from vllm.lora.punica_wrapper.utils import convert_mapping
+
+if TYPE_CHECKING:
+    # avoid circular import
+    from vllm.lora.layers import LoRAMapping
+
+from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
+
+from tpu_inference.lora.torch_lora_ops import bgmv_expand_slice, bgmv_shrink
+
+
+class PunicaWrapperTPU(PunicaWrapperBase):
+    """
+    PunicaWrapperTPU is designed to manage and provide metadata for the punica
+    kernel. The main function is to maintain the state information for
+    Multi-LoRA, and to provide the interface for the pytorch punica ops.
+
+    It is created by get_punica_wrapper during load_lora_model -> create_lora_manager. Device is TPU.
+    """
+
+    def __init__(self, max_num_batched_tokens: int, max_batches: int,
+                 device: Union[torch.device, str], **kwargs):
+        PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
+                                   device)
+
+        # PunicaWrapperBase defines some tensors with dtype=torch.int64, which
+        # isn't supported by the TPU. So convert those tensors to int32.
+        # Not all of them are used by the TPU so only convert the useful ones.
+        self._token_lora_indices = self._token_lora_indices.to(
+            dtype=torch.int32)  # map from token to LoRA index.
+        self._sampler_indices = self._sampler_indices.to(dtype=torch.int32)
+        self._sampler_indices_padded = self._sampler_indices_padded.to(
+            dtype=torch.int32)
+
+    def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor:
+        return torch.narrow(self._token_lora_indices, 0, 0, x.size(0))
+
+    @property
+    def embeddings_indices(self) -> torch.Tensor:
+        """
+        This property provides access to the indices used for lora embeddings,
+        specifically for VocabParallelEmbeddingWithLoRA.
+        """
+        raise NotImplementedError(
+            "NYI: torch_punica_tpu.PunicaWrapperTPU.embeddings_indices.")
+
+    @property
+    def sampler_indices_padded(self) -> torch.Tensor:
+        """
+        This property provides access to padded sampler indices.
+        """
+        raise NotImplementedError(
+            "NYI: torch_punica_tpu.PunicaWrapperTPU.sampler_indices_padded.")
+
+    def add_shrink(self, y: Union[tuple[torch.Tensor, ...], torch.Tensor],
+                   x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...],
+                   scale: float, **kwargs) -> Optional[torch.Tensor]:
+        """
+        Performs GEMM for multiple slices of lora_a.
+
+        Semantics:
+            for i in range(len(lora_a_stacked)):
+                y[i] += (x @ lora_a_stacked[i]) * scale
+
+        Args:
+            y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors. (n_slices, num_tokens, r)
+            x (torch.Tensor): Input tensor. (num_tokens, in_features)
+            lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights. lora_a_stacked[i]: (max_loras, 1, max_lora_rank, in_features)
+            scale (float): Scaling factor for the operation.
+        """
+        x = x.view(-1, x.shape[-1])
+
+        for slice_idx in range(len(lora_a_stacked)):
+            lora_s = lora_a_stacked[slice_idx]
+            y_s = bgmv_shrink(x, lora_s, self._get_token_lora_indices(x),
+                              scale)
+            y[slice_idx, :, :] = y_s  # type: ignore[index]
+        return y
+
+    def add_expand(self,
+                   y: torch.Tensor,
+                   x: Union[tuple[torch.Tensor, ...], torch.Tensor],
+                   lora_b_stacked: tuple[torch.Tensor, ...],
+                   output_slices: tuple[int, ...],
+                   offset_start: int = 0,
+                   add_inputs=True,
+                   **kwargs) -> torch.Tensor:
+        """
+        Performs GEMM for multiple slices of lora_b.
+
+        Semantics:
+            for i in range(len(lora_b_stacked)):
+                slice = output_slices[i]
+                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
+                offset += slice
+
+        Args:
+            y (torch.Tensor): Output tensor. (num_tokens, out_features)
+            x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors. (n_slices, num_tokens, r)
+            lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weights.
+            output_slices (tuple[int, ...]): Every slice's size.
+            add_inputs (bool): Defaults to True.
+        """
+        y_orig = y
+        y = y.view(-1, y.shape[-1])
+        offset_left = 0
+
+        for slice_idx in range(len(lora_b_stacked)):
+            y = bgmv_expand_slice(x[slice_idx], lora_b_stacked[slice_idx], y,
+                                  self._get_token_lora_indices(x[slice_idx]),
+                                  offset_left, output_slices[slice_idx],
+                                  add_inputs)
+            offset_left += output_slices[slice_idx]
+        return y.view(y_orig.shape)
+
+    def add_lora_embedding(self,
+                           y: torch.Tensor,
+                           x: torch.Tensor,
+                           lora_b_stacked: torch.Tensor,
+                           add_inputs: bool = True,
+                           **kwargs) -> torch.Tensor:
+        """
+        Applies lora specifically for VocabParallelEmbeddingWithLoRA.
+
+        Semantics:
+            y += x @ lora_b_stacked
+
+        Args:
+            y (torch.Tensor): Output tensor.
+            x (torch.Tensor): Input tensor.
+            lora_b_stacked (torch.Tensor): lora_b's weights.
+            add_inputs (bool): Defaults to True.
+        """
+        raise NotImplementedError(
+            "NYI: torch_punica_tpu.PunicaWrapperTPU.add_lora_embedding.")
+
+    def add_lora_linear(self,
+                        y: torch.Tensor,
+                        x: torch.Tensor,
+                        lora_a_stacked: tuple[torch.Tensor, ...],
+                        lora_b_stacked: tuple[torch.Tensor, ...],
+                        scale: float,
+                        output_slices: tuple[int, ...],
+                        *,
+                        buffer: Optional[tuple[torch.Tensor, ...]] = None,
+                        **kwargs) -> torch.Tensor:
+        """
+        Applicable to linear-related lora.
+
+        Semantics:
+            for i in range(len(lora_a_stacked)):
+                y[i] += (
+                    x[i].unsqueeze(0)
+                    @ lora_a_stacked[indices[i], layer_idx, :, :]
+                    @ lora_b_stacked[indices[i], layer_idx, :, :]
+                    * scale
+                ).squeeze(0)
+
+        Args:
+            y (torch.Tensor): Output tensor (bs, out_features). Will not be changed in-place.
+            x (torch.Tensor): Input tensor (bs, in_features)
+            lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight of length n_slices. lora_a_stacked[i]: (max_loras, 1, max_lora_rank, in_features)
+            lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight of length n_slices. lora_b_stacked[i]: (max_loras, 1, out_features, max_lora_rank)
+            output_slices (tuple[int, ...]): Every slice's size.
+            buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None.
+        """
+
+        assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
+
+        if buffer is None:
+            max_lora_rank = lora_b_stacked[0].size(-1)
+            num_tokens = x.size(0)
+            buffer = torch.zeros(
+                (len(output_slices), num_tokens, max_lora_rank),
+                dtype=x.dtype,
+                device=x.device,
+            )
+        buffer = self.add_shrink(
+            buffer, x, lora_a_stacked, scale,
+            **kwargs)  # (n_slices, num_tokens, max_lora_rank)
+        return self.add_expand(y,
+                               buffer,
+                               lora_b_stacked,
+                               output_slices,
+                               add_inputs=True,
+                               **kwargs)
+
+    def add_lora_logits(self,
+                        y: torch.Tensor,
+                        x: torch.Tensor,
+                        lora_a_stacked: torch.Tensor,
+                        lora_b_stacked: torch.Tensor,
+                        scale,
+                        *,
+                        buffer: Optional[torch.Tensor] = None,
+                        **kwargs) -> torch.Tensor:
+        """
+        Applies lora specifically for LogitsProcessorWithLoRA.
+
+        Semantics:
+            buffer = (x @ lora_a_stacked) * scale
+            y += buffer @ lora_b_stacked
+
+        Args:
+            y (torch.Tensor): Output tensor.
+            x (torch.Tensor): Input tensor.
+            lora_a_stacked (torch.Tensor): lora_a's weights.
+            lora_b_stacked (torch.Tensor): lora_b's weights.
+            scale (float): Scaling factor.
+            buffer (Optional[torch.Tensor]): Defaults to None.
+        """
+        raise NotImplementedError(
+            "NYI: torch_punica_tpu.PunicaWrapperTPU.add_lora_logits.")
+
+    @property
+    def token_lora_indices(self) -> torch.Tensor:
+        """
+        This property provides the lora indices corresponding to each token
+        in the batch. An index of -1 means no lora should be applied.
+        """
+        with torchax.default_env():
+            token_lora_len = self.indices_len[0]
+            return self._token_lora_indices[:token_lora_len]
+
+    # This performs the same tensor ops as the base method, except it does them
+    # on the CPU then transfers the results to the TPU.
+    def _update_base_metadata(
+        self,
+        mapping: "LoRAMapping",
+        lora_index_to_id: list[Optional[int]],
+        max_loras: int,
+        vocab_size: int,
+    ):
+        # Pad the prompt mapping to avoid running into recompiles on the TPU
+        # TODO: Should this happen inside mapping internally? If so how can we
+        # avoid having backend specific LoRAMapping classes?
+        mapping.prompt_mapping = self._pad_prompt_mapping(
+            mapping.prompt_mapping)
+
+        (
+            base_indices,
+            sampler_indices,
+            sampler_indices_padded,
+            embeddings_indices,
+            indices_len,
+        ) = convert_mapping(
+            mapping,
+            lora_index_to_id,
+            max_loras,
+            vocab_size,
+            0,  # extra_vocab_size
+            "cpu",
+        )
+        with torchax.default_env():
+            self._token_lora_indices = self._pad_to_shape(
+                base_indices, self._token_lora_indices.shape,
+                dims=1).to(self.device)
+            self._sampler_indices = self._pad_to_shape(
+                sampler_indices, self._sampler_indices.shape,
+                dims=1).to(self.device)
+            self._sampler_indices_padded = self._pad_to_shape(
+                sampler_indices_padded,
+                self._sampler_indices_padded.shape,
+                dims=1).to(self.device)
+            self._embeddings_indices = self._pad_to_shape(
+                embeddings_indices, self._embeddings_indices.shape,
+                dims=2).to(self.device)
+        self.indices_len[:] = indices_len
+
+    def _update_prefill_metadata(self,
+                                 token_lora_tensor: torch.Tensor) -> None:
+        with torchax.default_env():
+            self.batch_size = 1
+            self._lora_indices_per_batch[:self.
+                                         batch_size] = token_lora_tensor[:self.
+                                                                         batch_size].torch(
+                                                                         )
+
+    def _pad_prompt_mapping(
+            self, prompt_mapping: tuple[int, ...]) -> tuple[int, ...]:
+        num_reqs = len(prompt_mapping)
+
+        # From vllm/v1/worker/tpu_model_runner:51, but need to avoid a circular
+        # import
+        MIN_NUM_SEQS = 8
+
+        padded_num_reqs = max(2**math.ceil(math.log2(num_reqs)), MIN_NUM_SEQS)
+        pad_len = padded_num_reqs - num_reqs
+
+        padding = [-1] * pad_len
+        return tuple(list(prompt_mapping) + padding)
+
+    def _pad_to_shape(self, src, target_shape, dims=1):
+        if dims == 1:
+            pad_len = target_shape[0] - src.shape[0]
+            return F.pad(src, (0, pad_len), value=0).to(torch.int32)
+        else:
+            pad_rows = target_shape[0] - src.shape[0]
+            pad_cols = target_shape[1] - src.shape[1]
+            return F.pad(src, (0, pad_cols, 0, pad_rows),
+                         value=0).to(torch.int32)
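
For context on the padding in _pad_prompt_mapping above: the request count is rounded up to the next power of two, with a floor of MIN_NUM_SEQS = 8, so the TPU only ever sees a small fixed set of batch shapes and avoids recompilation. A standalone sketch of the same arithmetic (padded_len is a hypothetical helper, not part of the package):

    import math

    MIN_NUM_SEQS = 8  # mirrors the constant used in _pad_prompt_mapping

    def padded_len(num_reqs: int) -> int:
        # Round up to the next power of two, but never below MIN_NUM_SEQS.
        return max(2 ** math.ceil(math.log2(num_reqs)), MIN_NUM_SEQS)

    assert padded_len(3) == 8    # small batches all pad up to 8
    assert padded_len(9) == 16   # 9 requests pad to the next power of two
    assert padded_len(16) == 16  # exact powers of two are left unchanged
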
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.