tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260) hide show
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
@@ -0,0 +1,98 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+ import torch
7
+ from torchax.interop import call_jax
8
+
9
+
10
@jax.jit
def bgmv_jax(
    inputs,  # [num_tokens, hidden_size]
    loras,  # [num_loras, lora_rank, hidden_size]
    idxs,  # [num_tokens]
):
    """Batched gather-matmul-vector: apply each token's selected LoRA matrix.

    Every token ``t`` is multiplied by ``loras[idxs[t]]`` (contracting the
    hidden dimension), producing a ``[num_tokens, lora_rank]`` result. The
    gather is expressed as a one-hot selector so the whole op is a single
    einsum. ``jax.nn.one_hot`` encodes out-of-range indices (such as -1) as an
    all-zero row, so those tokens receive a zero contribution.
    """
    num_loras = loras.shape[0]
    selector = jax.nn.one_hot(idxs, num_loras,
                              dtype=inputs.dtype)  # [num_tokens, num_loras]
    return jnp.einsum("td,tX,Xld->tl", inputs, selector, loras)
22
+
23
+
24
def bgmv_torch(
    inputs,  # [num_tokens, hidden_size]
    loras,  # [num_loras, 1, lora_rank, hidden_size]
    idxs,  # [num_tokens]
):  # [num_tokens, lora_rank]
    """Torch-facing entry point for the batched gather-matmul-vector op.

    A 4-D ``loras`` tensor has its singleton axis 1 removed. The computation
    is then delegated to the jitted JAX kernel ``bgmv_jax`` through torchax's
    ``call_jax`` bridge.
    """
    # TODO(xiowei): once torchax is upgraded, drop the JAX round-trip in favour
    # of the pure-torch one_hot implementation (added in
    # https://github.com/pytorch/xla/pull/9523), i.e.:
    #   torch.einsum("td,tX,Xld->tl", inputs,
    #                torch.nn.functional.one_hot(idxs.long(), loras.shape[0]),
    #                loras)  # [num_tokens, lora_rank]
    squeezed_loras = loras.squeeze(axis=1) if len(loras.shape) == 4 else loras
    return call_jax(bgmv_jax, inputs, squeezed_loras, idxs)
42
+
43
+
44
def bgmv_shrink(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    scaling: float = 1.0,
):
    """Project token activations down to the LoRA rank (the "shrink" step).

    For every token ``t`` this computes
    ``scaling * (inputs[t] @ W[lora_indices_tensor[t]].T)``, where ``W`` is
    ``lora_b_weights`` with its singleton axis removed.

    Args:
        inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
        lora_b_weights (torch.Tensor): Stacked per-LoRA down-projection weights
            of shape [max_loras, 1, max_lora_rank, hidden_size]. Despite the
            parameter name, ``PunicaWrapperTPU.add_shrink`` passes the LoRA A
            weights here.
        lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
            indicating which LoRA matrix to use for each token. Indices outside
            [0, max_loras), e.g. -1, select no LoRA and yield zeros.
        scaling (float, optional): Scalar multiplier applied to the output.

    Returns:
        torch.Tensor: Shrunk activations of shape [num_tokens, max_lora_rank].
    """
    return scaling * bgmv_torch(inputs, lora_b_weights, lora_indices_tensor)
61
+
62
+
63
def bgmv_expand_slice(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = True,
):
    """Expand LoRA-rank activations into one column slice of the output.

    Computes each token's ``inputs[t] @ B[lora_indices_tensor[t]].T`` and
    writes it into columns ``[slice_offset, slice_offset + slice_size)`` of a
    tensor shaped like ``output_tensor``.

    Args:
        inputs (torch.Tensor): Input tensor of shape [num_tokens, lora_rank].

        lora_b_weights (torch.Tensor): LoRA weights of shape
            [num_loras, 1, out_features, lora_rank].

        output_tensor (torch.Tensor): output tensor of shape
            [num_tokens, out_features * num_slices]. It is not modified in
            place; a new tensor is returned.

        lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
            indicating which LoRA matrix to use for each token.

        slice_offset (int): First output column of the target slice.

        slice_size (int): Width of the target slice; expected to equal
            out_features of ``lora_b_weights``.

        add_inputs (bool): If True, the LoRA result is added to the existing
            values of ``output_tensor`` in the target slice. If False, the
            target slice is overwritten. In both cases every column outside the
            slice keeps its value from ``output_tensor``.

    Returns:
        torch.Tensor: A new tensor with the same shape as ``output_tensor``.
    """
    outputs = bgmv_torch(inputs, lora_b_weights,
                         lora_indices_tensor)  # [num_tokens, out_features]

    if add_inputs:
        # Create a padded tensor manually to avoid issues with F.pad on sharded
        # tensors. This is a more robust way to handle padding in a
        # distributed environment.
        outputs_padded = torch.zeros_like(output_tensor)
        outputs_padded[:, slice_offset:slice_offset + slice_size] = outputs
        return output_tensor + outputs_padded

    # Overwrite only the target slice and keep the remaining columns. Returning
    # a zero-padded tensor here would discard the slices written earlier by a
    # caller that threads the result through a loop (e.g. add_expand).
    result = output_tensor.clone()
    result[:, slice_offset:slice_offset + slice_size] = outputs
    return result
@@ -0,0 +1,310 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+
4
+ import math
5
+ from typing import TYPE_CHECKING, Optional, Union
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import torchax
10
+ from vllm.lora.punica_wrapper.utils import convert_mapping
11
+
12
+ if TYPE_CHECKING:
13
+ # avoid circuit import
14
+ from vllm.lora.layers import LoRAMapping
15
+
16
+ from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
17
+
18
+ from tpu_inference.lora.torch_lora_ops import bgmv_expand_slice, bgmv_shrink
19
+
20
+
21
class PunicaWrapperTPU(PunicaWrapperBase):
    """
    PunicaWrapperTPU is designed to manage and provide metadata for the punica
    kernel. The main function is to maintain the state information for
    Multi-LoRA, and to provide the interface for the pytorch punica ops.

    The shrink/expand work is delegated to ``bgmv_shrink`` and
    ``bgmv_expand_slice``. Embedding- and logits-LoRA entry points, plus the
    ``embeddings_indices`` / ``sampler_indices_padded`` accessors, are not yet
    implemented and raise ``NotImplementedError``.

    It is created by get_punica_wrapper when we
    load_lora_model->create_lora_manager. Device is TPU.
    """

    def __init__(self, max_num_batched_tokens: int, max_batches: int,
                 device: Union[torch.device, str], **kwargs):
        """Initialise base metadata buffers and narrow them to int32.

        Extra ``**kwargs`` are accepted and ignored.
        """
        PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
                                   device)

        # PunicaWrapperBase defines some tensors with dtype=torch.int64, which
        # isn't supported by the TPU. So convert those tensors to int32.
        # Not all of them are used by the TPU so only convert the useful ones.
        self._token_lora_indices = self._token_lora_indices.to(
            dtype=torch.int32)  # map from token to LoRA index.
        self._sampler_indices = self._sampler_indices.to(dtype=torch.int32)
        self._sampler_indices_padded = self._sampler_indices_padded.to(
            dtype=torch.int32)

    def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor:
        """Return the first ``x.size(0)`` entries of the token->LoRA index map."""
        return torch.narrow(self._token_lora_indices, 0, 0, x.size(0))

    @property
    def embeddings_indices(self) -> torch.Tensor:
        """
        This property provides access to the indices used for lora embeddings,
        specifically for VocabParallelEmbeddingWithLoRA.
        """
        raise NotImplementedError(
            "NYI: torch_punica_tpu.PunicaWrapperTPU.embeddings_indices.")

    @property
    def sampler_indices_padded(self) -> torch.Tensor:
        """
        This property provides access to padded sampler indices.
        """
        raise NotImplementedError(
            "NYI: torch_punica_tpu.PunicaWrapperTPU.sampler_indices_padded.")

    def add_shrink(self, y: Union[tuple[torch.Tensor, ...], torch.Tensor],
                   x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...],
                   scale: float, **kwargs) -> Optional[torch.Tensor]:
        """
        Performs GEMM for multiple slices of lora_a.

        Semantics (per token, using that token's selected LoRA):
            for i in range(len(lora_a_stacked)):
                y[i] = (x @ lora_a_stacked[i].T) * scale

        Note that each ``y[i]`` is overwritten, not accumulated into.

        Args:
            y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors. (n_slices, num_tokens, r)
                Written in place via item assignment.
            x (torch.Tensor): Input tensor. (num_tokens, in_features)
            lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights. lora_a_stacked[i]: (max_loras, 1, max_lora_rank, in_features)
            scale (float): Scaling factor for the operation

        Returns:
            The same ``y`` object that was passed in.
        """
        x = x.view(-1, x.shape[-1])

        for slice_idx in range(len(lora_a_stacked)):
            lora_s = lora_a_stacked[slice_idx]
            y_s = bgmv_shrink(x, lora_s, self._get_token_lora_indices(x),
                              scale)
            y[slice_idx, :, :] = y_s  # type: ignore[index]
        return y

    def add_expand(self,
                   y: torch.Tensor,
                   x: Union[tuple[torch.Tensor, ...], torch.Tensor],
                   lora_b_stacked: tuple[torch.Tensor, ...],
                   output_slices: tuple[int, ...],
                   offset_start: int = 0,
                   add_inputs=True,
                   **kwargs) -> torch.Tensor:
        """
        Performs GEMM for multiple slices of lora_b.

        Semantics:
            offset = offset_start
            for i in range(len(lora_b_stacked)):
                slice = output_slices[i]
                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
                offset += slice

        Args:
            y (torch.Tensor): Output tensor. (num_tokens, out_features)
            x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors. (n_slices, num_tokens, r)
            lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
            output_slices (tuple[int, ...]): Every slice's size
            offset_start (int): Output column at which the first slice starts.
            add_inputs (bool): If True (default), add to the existing values of
                y; otherwise overwrite each slice.

        Returns:
            torch.Tensor: The updated output, reshaped to ``y``'s shape.
        """
        y_orig = y
        y = y.view(-1, y.shape[-1])
        # Honour offset_start; previously slices always began at column 0.
        offset_left = offset_start

        for slice_idx in range(len(lora_b_stacked)):
            y = bgmv_expand_slice(x[slice_idx], lora_b_stacked[slice_idx], y,
                                  self._get_token_lora_indices(x[slice_idx]),
                                  offset_left, output_slices[slice_idx],
                                  add_inputs)
            offset_left += output_slices[slice_idx]
        return y.view(y_orig.shape)

    def add_lora_embedding(self,
                           y: torch.Tensor,
                           x: torch.Tensor,
                           lora_b_stacked: torch.Tensor,
                           add_inputs: bool = True,
                           **kwargs) -> torch.Tensor:
        """
        Applies lora specifically for VocabParallelEmbeddingWithLoRA.

        Semantics:
            y += x @ lora_b_stacked

        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_b_stacked (torch.Tensor): lora_b's weights.
            add_inputs (bool): Default to True.
        """
        raise NotImplementedError(
            "NYI: torch_punica_tpu.PunicaWrapperTPU.add_lora_embedding.")

    def add_lora_linear(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: tuple[torch.Tensor, ...],
                        lora_b_stacked: tuple[torch.Tensor, ...],
                        scale: float,
                        output_slices: tuple[int, ...],
                        *,
                        buffer: Optional[tuple[torch.Tensor, ...]] = None,
                        **kwargs) -> torch.Tensor:
        """
        Applicable to linear-related lora.

        Semantics:
            for i in range(len(lora_a_stacked)):
                y[i] += (
                    x[i].unsqueeze(0)
                    @ lora_a_stacked[indices[i], layer_idx, :, :]
                    @ lora_b_stacked[indices[i], layer_idx, :, :]
                    * scale
                    ).squeeze(0)

        Args:
            y (torch.Tensor): Output tensor (bs, out_features). Will not be changed in-place.
            x (torch.Tensor): Input tensor (bs, in_features)
            lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight of length n_slices. lora_a_stacked[i]: (max_loras, 1, max_lora_rank, in_features)
            lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight of length n_slices. lora_b_stacked[i]: (max_loras, 1, out_features, max_lora_rank)
            scale (float): Scaling factor applied in the shrink step.
            output_slices (tuple[int, ...]): Every slice's size.
            buffer (Optional[tuple[torch.Tensor, ...]]): Scratch space for the
                shrink result. Defaults to None, in which case a zero tensor of
                shape (n_slices, num_tokens, max_lora_rank) is allocated.

        Returns:
            torch.Tensor: y with the LoRA contribution added.
        """

        assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)

        if buffer is None:
            max_lora_rank = lora_b_stacked[0].size(-1)
            num_tokens = x.size(0)
            buffer = torch.zeros(
                (len(output_slices), num_tokens, max_lora_rank),
                dtype=x.dtype,
                device=x.device,
            )
        buffer = self.add_shrink(
            buffer, x, lora_a_stacked, scale,
            **kwargs)  # (n_slices, num_tokens, max_lora_rank)
        return self.add_expand(y,
                               buffer,
                               lora_b_stacked,
                               output_slices,
                               add_inputs=True,
                               **kwargs)

    def add_lora_logits(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: torch.Tensor,
                        lora_b_stacked: torch.Tensor,
                        scale,
                        *,
                        buffer: Optional[torch.Tensor] = None,
                        **kwargs) -> torch.Tensor:
        """
        Applies lora specifically for LogitsProcessorWithLoRA.

        Semantics:
            buffer = (x @ lora_a_stacked) * scale
            y += buffer @ lora_b_stacked

        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_a_stacked (torch.Tensor): lora_a's weights.
            lora_b_stacked (torch.Tensor): lora_b's weights.
            scale (float): Scaling factor.
            buffer (Optional[torch.Tensor]): Default to None.
        """
        raise NotImplementedError(
            "NYI: torch_punica_tpu.PunicaWrapperTPU.add_lora_logits.")

    @property
    def token_lora_indices(self) -> torch.Tensor:
        """
        This property provides the lora indices corresponding to each token
        in the batch. An index of -1 means no lora should be applied.
        """
        with torchax.default_env():
            token_lora_len = self.indices_len[0]
            return self._token_lora_indices[:token_lora_len]

    # This performs the same tensor ops as the base method, except it does them
    # on the CPU then transfers the results to the TPU
    def _update_base_metadata(
        self,
        mapping: "LoRAMapping",
        lora_index_to_id: list[Optional[int]],
        max_loras: int,
        vocab_size: int,
    ):
        # Pad the prompt mapping to avoid running into recompiles on the TPU.
        # Note: this rebinds mapping.prompt_mapping on the caller's object.
        # TODO: Should this happen inside mapping internally? If so how can we
        # avoid having backend specific LoRAMapping classes?
        mapping.prompt_mapping = self._pad_prompt_mapping(
            mapping.prompt_mapping)

        (
            base_indices,
            sampler_indices,
            sampler_indices_padded,
            embeddings_indices,
            indices_len,
        ) = convert_mapping(
            mapping,
            lora_index_to_id,
            max_loras,
            vocab_size,
            0,  # extra_vocab_size
            "cpu",
        )
        with torchax.default_env():
            # Pad each CPU result to the preallocated buffer shape so tensor
            # shapes stay fixed across calls, then move it to the device.
            self._token_lora_indices = self._pad_to_shape(
                base_indices, self._token_lora_indices.shape,
                dims=1).to(self.device)
            self._sampler_indices = self._pad_to_shape(
                sampler_indices, self._sampler_indices.shape,
                dims=1).to(self.device)
            self._sampler_indices_padded = self._pad_to_shape(
                sampler_indices_padded,
                self._sampler_indices_padded.shape,
                dims=1).to(self.device)
            self._embeddings_indices = self._pad_to_shape(
                embeddings_indices, self._embeddings_indices.shape,
                dims=2).to(self.device)
            self.indices_len[:] = indices_len

    def _update_prefill_metadata(self,
                                 token_lora_tensor: torch.Tensor) -> None:
        """Record the LoRA index for a single batch (batch_size is fixed to 1)."""
        with torchax.default_env():
            self.batch_size = 1
            # NOTE(review): .torch() presumably converts the torchax tensor to
            # a plain torch tensor before assignment — confirm.
            self._lora_indices_per_batch[:self.batch_size] = (
                token_lora_tensor[:self.batch_size].torch())

    def _pad_prompt_mapping(
            self, prompt_mapping: tuple[int, ...]) -> tuple[int, ...]:
        """Pad ``prompt_mapping`` with -1 up to a power-of-two length.

        The padded length is the smallest power of two >= the number of
        requests, and at least ``MIN_NUM_SEQS``. An empty mapping is padded to
        ``MIN_NUM_SEQS`` instead of failing on ``log2(0)``.
        """
        num_reqs = len(prompt_mapping)

        # From vllm/v1/worker/tpu_model_runner:51, but need to avoid a circular
        # import
        MIN_NUM_SEQS = 8

        # max(num_reqs, 1) keeps log2 in its domain for an empty mapping.
        padded_num_reqs = max(2**math.ceil(math.log2(max(num_reqs, 1))),
                              MIN_NUM_SEQS)
        pad_len = padded_num_reqs - num_reqs

        padding = [-1] * pad_len
        return tuple(list(prompt_mapping) + padding)

    def _pad_to_shape(self, src, target_shape, dims=1):
        """Zero-pad ``src`` at the end up to ``target_shape`` and cast to int32.

        ``dims=1`` pads the last axis to ``target_shape[0]``; any other value
        pads a 2-D tensor's rows and columns to ``target_shape[:2]``.
        """
        if dims == 1:
            pad_len = target_shape[0] - src.shape[0]
            return F.pad(src, (0, pad_len), value=0).to(torch.int32)
        else:
            pad_rows = target_shape[0] - src.shape[0]
            pad_cols = target_shape[1] - src.shape[1]
            return F.pad(src, (0, pad_cols, 0, pad_rows),
                         value=0).to(torch.int32)
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.