tpu-inference 0.11.1.dev202511150811__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of tpu-inference might be problematic.
Files changed (179)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +105 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +654 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +96 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +182 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +236 -0
  27. tpu_inference/__init__.py +34 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +51 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +728 -0
  37. tpu_inference/distributed/utils.py +59 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +107 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +362 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +390 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +507 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +105 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +311 -0
  120. tpu_inference/mock/__init__.py +0 -0
  121. tpu_inference/mock/vllm_config_utils.py +28 -0
  122. tpu_inference/mock/vllm_envs.py +1219 -0
  123. tpu_inference/mock/vllm_logger.py +212 -0
  124. tpu_inference/mock/vllm_logging_utils.py +15 -0
  125. tpu_inference/models/__init__.py +0 -0
  126. tpu_inference/models/common/__init__.py +0 -0
  127. tpu_inference/models/common/model_loader.py +444 -0
  128. tpu_inference/models/jax/__init__.py +0 -0
  129. tpu_inference/models/jax/deepseek_v3.py +868 -0
  130. tpu_inference/models/jax/gpt_oss.py +492 -0
  131. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  132. tpu_inference/models/jax/llama3.py +375 -0
  133. tpu_inference/models/jax/llama4.py +629 -0
  134. tpu_inference/models/jax/llama_eagle3.py +333 -0
  135. tpu_inference/models/jax/phi3.py +376 -0
  136. tpu_inference/models/jax/qwen2.py +375 -0
  137. tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
  138. tpu_inference/models/jax/qwen3.py +302 -0
  139. tpu_inference/models/jax/utils/__init__.py +0 -0
  140. tpu_inference/models/jax/utils/file_utils.py +96 -0
  141. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  142. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  143. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  144. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  145. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  146. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  147. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  148. tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
  149. tpu_inference/models/jax/utils/weight_utils.py +529 -0
  150. tpu_inference/models/vllm/__init__.py +0 -0
  151. tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
  152. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  153. tpu_inference/platforms/__init__.py +2 -0
  154. tpu_inference/platforms/tpu_platform.py +269 -0
  155. tpu_inference/runner/__init__.py +0 -0
  156. tpu_inference/runner/block_table.py +122 -0
  157. tpu_inference/runner/compilation_manager.py +780 -0
  158. tpu_inference/runner/input_batch.py +435 -0
  159. tpu_inference/runner/kv_cache.py +132 -0
  160. tpu_inference/runner/kv_cache_manager.py +479 -0
  161. tpu_inference/runner/lora_utils.py +92 -0
  162. tpu_inference/runner/multimodal_manager.py +217 -0
  163. tpu_inference/runner/persistent_batch_manager.py +244 -0
  164. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  165. tpu_inference/runner/structured_decoding_manager.py +88 -0
  166. tpu_inference/runner/tpu_runner.py +1620 -0
  167. tpu_inference/runner/utils.py +426 -0
  168. tpu_inference/spec_decode/__init__.py +0 -0
  169. tpu_inference/spec_decode/jax/__init__.py +0 -0
  170. tpu_inference/spec_decode/jax/eagle3.py +367 -0
  171. tpu_inference/tpu_info.py +77 -0
  172. tpu_inference/utils.py +317 -0
  173. tpu_inference/worker/__init__.py +0 -0
  174. tpu_inference/worker/tpu_worker.py +321 -0
  175. tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
  176. tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
  177. tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
  178. tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
  179. tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tpu_inference/lora/torch_lora_ops.py
@@ -0,0 +1,103 @@
+ # SPDX-License-Identifier: Apache-2.0
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+ import jax
+ import jax.numpy as jnp
+ import torch
+ import torch.nn.functional as F
+ from torchax.interop import call_jax
+
+
+ @jax.jit
+ def bgmv_jax(
+     inputs,  # [num_tokens, hidden_size]
+     loras,  # [num_loras, lora_rank, hidden_size]
+     idxs,  # [num_tokens]
+ ):
+     return jnp.einsum(
+         "td,tX,Xld->tl",
+         inputs,
+         jax.nn.one_hot(idxs, loras.shape[0], dtype=inputs.dtype),
+         loras,
+     )
+
+
+ def bgmv_torch(
+     inputs,  # [num_tokens, hidden_size]
+     loras,  # [num_loras, 1, lora_rank, hidden_size]
+     idxs,  # [num_tokens]
+ ):  # [num_tokens, lora_rank]
+     # TODO(xiowei): use the below one_hot impl (added in https://github.com/pytorch/xla/pull/9523) after we upgrade torchax version.
+     # if len(loras.shape) == 4:
+     #     loras = loras.squeeze(axis=1)
+     # return torch.einsum(
+     #     "td,tX,Xld->tl",
+     #     inputs,
+     #     torch.nn.functional.one_hot(idxs.long(), loras.shape[0]),
+     #     loras,
+     # )  # [num_tokens, lora_rank]
+
+     if len(loras.shape) == 4:
+         loras = loras.squeeze(axis=1)
+     return call_jax(bgmv_jax, inputs, loras, idxs)
+
+
+ def bgmv_shrink(
+     inputs: torch.Tensor,
+     lora_b_weights: torch.Tensor,
+     lora_indices_tensor: torch.Tensor,
+     scaling: float = 1.0,
+ ):
+     """
+     Args:
+         inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
+         lora_b_weights (torch.Tensor): LoRA weights of shape
+             [max_loras, 1, max_lora_rank, hidden_size].
+         output_tensor (torch.Tensor): (Unused) output tensor (placeholder).
+         lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
+             indicating which LoRA matrix to use for each token.
+         scaling (float, optional): Scalar multiplier applied to the output.
+     """
+     return scaling * bgmv_torch(inputs, lora_b_weights, lora_indices_tensor)
+
+
+ def bgmv_expand_slice(
+     inputs: torch.Tensor,
+     lora_b_weights: torch.Tensor,
+     output_tensor: torch.Tensor,
+     lora_indices_tensor: torch.Tensor,
+     slice_offset: int,
+     slice_size: int,
+     add_inputs: bool = True,
+ ):
+     """
+     Args:
+         inputs (torch.Tensor): Input tensor of shape [num_tokens, lora_rank].
+
+         lora_b_weights (torch.Tensor): LoRA weights of shape
+             [num_loras, 1, out_features, lora_rank].
+
+         output_tensor (torch.Tensor): Output tensor of shape
+             [num_tokens, out_features * num_slices].
+
+         lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
+             indicating which LoRA matrix to use for each token.
+         add_inputs (bool): Whether or not to add the input tensor to the output
+             tensor.
+     """
+     outputs = bgmv_torch(inputs, lora_b_weights, lora_indices_tensor)
+
+     outputs = F.pad(
+         outputs,
+         (
+             slice_offset,
+             output_tensor.shape[1] - (slice_offset + slice_size),
+             0,
+             0,
+         ),
+     )
+
+     if add_inputs:
+         return output_tensor + outputs
+     else:
+         return outputs
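The "td,tX,Xld->tl" einsum in bgmv_jax is the batched grouped matrix-vector (BGMV) product behind these LoRA ops: jax.nn.one_hot(idxs, num_loras) selects each token's LoRA matrix, and the contraction applies that matrix to the token's hidden vector. Below is a minimal standalone sketch (not part of this wheel; array sizes are arbitrary illustration values) that checks the einsum against an explicit per-token loop:

import jax
import jax.numpy as jnp

# Arbitrary sizes for illustration only.
num_tokens, hidden_size, num_loras, lora_rank = 4, 16, 3, 8
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)

inputs = jax.random.normal(k1, (num_tokens, hidden_size))
loras = jax.random.normal(k2, (num_loras, lora_rank, hidden_size))
idxs = jax.random.randint(k3, (num_tokens,), 0, num_loras)

# One-hot selection folded into the contraction, as in bgmv_jax.
einsum_out = jnp.einsum(
    "td,tX,Xld->tl",
    inputs,
    jax.nn.one_hot(idxs, num_loras, dtype=inputs.dtype),
    loras,
)

# Reference: gather each token's LoRA matrix and apply it directly.
loop_out = jnp.stack(
    [loras[idxs[t]] @ inputs[t] for t in range(num_tokens)])

assert jnp.allclose(einsum_out, loop_out, atol=1e-5)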
tpu_inference/lora/torch_punica_tpu.py
@@ -0,0 +1,311 @@
+ # SPDX-License-Identifier: Apache-2.0
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+ import math
+ from typing import TYPE_CHECKING, Optional, Union
+
+ import torch
+ import torch.nn.functional as F
+ import torchax
+ from vllm.lora.punica_wrapper.utils import convert_mapping
+
+ if TYPE_CHECKING:
+     # avoid circular import
+     from vllm.lora.layers import LoRAMapping
+
+ from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
+
+ from tpu_inference.lora.torch_lora_ops import bgmv_expand_slice, bgmv_shrink
+
+
+ class PunicaWrapperTPU(PunicaWrapperBase):
+     """
+     PunicaWrapperTPU is designed to manage and provide metadata for the punica
+     kernel. Its main function is to maintain the state information for
+     Multi-LoRA and to provide the interface for the PyTorch punica ops.
+
+     It is created by get_punica_wrapper in load_lora_model -> create_lora_manager. The device is TPU.
+     """
+
+     def __init__(self, max_num_batched_tokens: int, max_batches: int,
+                  device: Union[torch.device, str], **kwargs):
+         PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
+                                    device)
+
+         # PunicaWrapperBase defines some tensors with dtype=torch.int64, which
+         # isn't supported by the TPU. So convert those tensors to int32.
+         # Not all of them are used by the TPU, so only convert the useful ones.
+         self._token_lora_indices = self._token_lora_indices.to(
+             dtype=torch.int32)  # map from token to LoRA index.
+         self._sampler_indices = self._sampler_indices.to(dtype=torch.int32)
+         self._sampler_indices_padded = self._sampler_indices_padded.to(
+             dtype=torch.int32)
+
+     def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor:
+         return torch.narrow(self._token_lora_indices, 0, 0, x.size(0))
+
+     @property
+     def embeddings_indices(self) -> torch.Tensor:
+         """
+         This property provides access to the indices used for LoRA embeddings,
+         specifically for VocabParallelEmbeddingWithLoRA.
+         """
+         raise NotImplementedError(
+             "NYI: torch_punica_tpu.PunicaWrapperTPU.embeddings_indices.")
+
+     @property
+     def sampler_indices_padded(self) -> torch.Tensor:
+         """
+         This property provides access to padded sampler indices.
+         """
+         raise NotImplementedError(
+             "NYI: torch_punica_tpu.PunicaWrapperTPU.sampler_indices_padded.")
+
+     def add_shrink(self, y: Union[tuple[torch.Tensor, ...], torch.Tensor],
+                    x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...],
+                    scale: float, **kwargs) -> Optional[torch.Tensor]:
+         """
+         Performs GEMM for multiple slices of lora_a.
+
+         Semantics:
+             for i in range(len(lora_a_stacked)):
+                 y[i] += (x @ lora_a_stacked[i]) * scale
+
+         Args:
+             y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors. (n_slices, num_tokens, r)
+             x (torch.Tensor): Input tensor. (num_tokens, in_features)
+             lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights. lora_a_stacked[i]: (max_loras, 1, max_lora_rank, in_features)
+             scale (float): Scaling factor for the operation.
+         """
+         x = x.view(-1, x.shape[-1])
+
+         for slice_idx in range(len(lora_a_stacked)):
+             lora_s = lora_a_stacked[slice_idx]
+             y_s = bgmv_shrink(x, lora_s, self._get_token_lora_indices(x),
+                               scale)
+             y[slice_idx, :, :] = y_s  # type: ignore[index]
+         return y
+
+     def add_expand(self,
+                    y: torch.Tensor,
+                    x: Union[tuple[torch.Tensor, ...], torch.Tensor],
+                    lora_b_stacked: tuple[torch.Tensor, ...],
+                    output_slices: tuple[int, ...],
+                    offset_start: int = 0,
+                    add_inputs=True,
+                    **kwargs) -> torch.Tensor:
+         """
+         Performs GEMM for multiple slices of lora_b.
+
+         Semantics:
+             for i in range(len(lora_b_stacked)):
+                 slice = output_slices[i]
+                 y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
+                 offset += slice
+
+         Args:
+             y (torch.Tensor): Output tensor. (num_tokens, out_features)
+             x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors. (n_slices, num_tokens, r)
+             lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weights.
+             output_slices (tuple[int, ...]): Every slice's size.
+             add_inputs (bool): Defaults to True.
+         """
+         y_orig = y
+         y = y.view(-1, y.shape[-1])
+         offset_left = 0
+
+         for slice_idx in range(len(lora_b_stacked)):
+             y = bgmv_expand_slice(x[slice_idx], lora_b_stacked[slice_idx], y,
+                                   self._get_token_lora_indices(x[slice_idx]),
+                                   offset_left, output_slices[slice_idx],
+                                   add_inputs)
+             offset_left += output_slices[slice_idx]
+         return y.view(y_orig.shape)
+
+     def add_lora_embedding(self,
+                            y: torch.Tensor,
+                            x: torch.Tensor,
+                            lora_b_stacked: torch.Tensor,
+                            add_inputs: bool = True,
+                            **kwargs) -> torch.Tensor:
+         """
+         Applies LoRA specifically for VocabParallelEmbeddingWithLoRA.
+
+         Semantics:
+             y += x @ lora_b_stacked
+
+         Args:
+             y (torch.Tensor): Output tensor.
+             x (torch.Tensor): Input tensor.
+             lora_b_stacked (torch.Tensor): lora_b's weights.
+             add_inputs (bool): Defaults to True.
+         """
+         raise NotImplementedError(
+             "NYI: torch_punica_tpu.PunicaWrapperTPU.add_lora_embedding.")
+
+     def add_lora_linear(self,
+                         y: torch.Tensor,
+                         x: torch.Tensor,
+                         lora_a_stacked: tuple[torch.Tensor, ...],
+                         lora_b_stacked: tuple[torch.Tensor, ...],
+                         scale: float,
+                         output_slices: tuple[int, ...],
+                         *,
+                         buffer: Optional[tuple[torch.Tensor, ...]] = None,
+                         **kwargs) -> torch.Tensor:
+         """
+         Applicable to linear-related LoRA.
+
+         Semantics:
+             for i in range(len(lora_a_stacked)):
+                 y[i] += (
+                     x[i].unsqueeze(0)
+                     @ lora_a_stacked[indices[i], layer_idx, :, :]
+                     @ lora_b_stacked[indices[i], layer_idx, :, :]
+                     * scale
+                 ).squeeze(0)
+
+         Args:
+             y (torch.Tensor): Output tensor (bs, out_features). Will not be changed in-place.
+             x (torch.Tensor): Input tensor (bs, in_features).
+             lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights of length n_slices. lora_a_stacked[i]: (max_loras, 1, max_lora_rank, in_features)
+             lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weights of length n_slices. lora_b_stacked[i]: (max_loras, 1, out_features, max_lora_rank)
+             output_slices (tuple[int, ...]): Every slice's size.
+             buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None.
+         """
+
+         assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
+
+         if buffer is None:
+             max_lora_rank = lora_b_stacked[0].size(-1)
+             num_tokens = x.size(0)
+             buffer = torch.zeros(
+                 (len(output_slices), num_tokens, max_lora_rank),
+                 dtype=x.dtype,
+                 device=x.device,
+             )
+         buffer = self.add_shrink(
+             buffer, x, lora_a_stacked, scale,
+             **kwargs)  # (n_slices, num_tokens, max_lora_rank)
+         return self.add_expand(y,
+                                buffer,
+                                lora_b_stacked,
+                                output_slices,
+                                add_inputs=True,
+                                **kwargs)
+
+     def add_lora_logits(self,
+                         y: torch.Tensor,
+                         x: torch.Tensor,
+                         lora_a_stacked: torch.Tensor,
+                         lora_b_stacked: torch.Tensor,
+                         scale,
+                         *,
+                         buffer: Optional[torch.Tensor] = None,
+                         **kwargs) -> torch.Tensor:
+         """
+         Applies LoRA specifically for LogitsProcessorWithLoRA.
+
+         Semantics:
+             buffer = (x @ lora_a_stacked) * scale
+             y += buffer @ lora_b_stacked
+
+         Args:
+             y (torch.Tensor): Output tensor.
+             x (torch.Tensor): Input tensor.
+             lora_a_stacked (torch.Tensor): lora_a's weights.
+             lora_b_stacked (torch.Tensor): lora_b's weights.
+             scale (float): Scaling factor.
+             buffer (Optional[torch.Tensor]): Defaults to None.
+         """
+         raise NotImplementedError(
+             "NYI: torch_punica_tpu.PunicaWrapperTPU.add_lora_logits.")
+
+     @property
+     def token_lora_indices(self) -> torch.Tensor:
+         """
+         This property provides the LoRA indices corresponding to each token
+         in the batch. An index of -1 means no LoRA should be applied.
+         """
+         with torchax.default_env():
+             token_lora_len = self.indices_len[0]
+             return self._token_lora_indices[:token_lora_len]
+
+     # This performs the same tensor ops as the base method, except it does them
+     # on the CPU and then transfers the results to the TPU.
+     def _update_base_metadata(
+         self,
+         mapping: "LoRAMapping",
+         lora_index_to_id: list[Optional[int]],
+         max_loras: int,
+         vocab_size: int,
+         extra_vocab_size: int,
+     ):
+         # Pad the prompt mapping to avoid running into recompiles on the TPU.
+         # TODO: Should this happen inside mapping internally? If so, how can we
+         # avoid having backend-specific LoRAMapping classes?
+         mapping.prompt_mapping = self._pad_prompt_mapping(
+             mapping.prompt_mapping)
+
+         (
+             base_indices,
+             sampler_indices,
+             sampler_indices_padded,
+             embeddings_indices,
+             indices_len,
+         ) = convert_mapping(
+             mapping,
+             lora_index_to_id,
+             max_loras,
+             vocab_size,
+             extra_vocab_size,
+             "cpu",
+         )
+         with torchax.default_env():
+             self._token_lora_indices = self._pad_to_shape(
+                 base_indices, self._token_lora_indices.shape,
+                 dims=1).to(self.device)
+             self._sampler_indices = self._pad_to_shape(
+                 sampler_indices, self._sampler_indices.shape,
+                 dims=1).to(self.device)
+             self._sampler_indices_padded = self._pad_to_shape(
+                 sampler_indices_padded,
+                 self._sampler_indices_padded.shape,
+                 dims=1).to(self.device)
+             self._embeddings_indices = self._pad_to_shape(
+                 embeddings_indices, self._embeddings_indices.shape,
+                 dims=2).to(self.device)
+         self.indices_len[:] = indices_len
+
+     def _update_prefill_metadata(self,
+                                  token_lora_tensor: torch.Tensor) -> None:
+         with torchax.default_env():
+             self.batch_size = 1
+             self._lora_indices_per_batch[:self.
+                                          batch_size] = token_lora_tensor[:self.
+                                                                          batch_size].torch(
+                                                                          )
+
+     def _pad_prompt_mapping(
+             self, prompt_mapping: tuple[int, ...]) -> tuple[int, ...]:
+         num_reqs = len(prompt_mapping)
+
+         # From vllm/v1/worker/tpu_model_runner:51, but we need to avoid a circular
+         # import.
+         MIN_NUM_SEQS = 8
+
+         padded_num_reqs = max(2**math.ceil(math.log2(num_reqs)), MIN_NUM_SEQS)
+         pad_len = padded_num_reqs - num_reqs
+
+         padding = [-1] * pad_len
+         return tuple(list(prompt_mapping) + padding)
+
+     def _pad_to_shape(self, src, target_shape, dims=1):
+         if dims == 1:
+             pad_len = target_shape[0] - src.shape[0]
+             return F.pad(src, (0, pad_len), value=0).to(torch.int32)
+         else:
+             pad_rows = target_shape[0] - src.shape[0]
+             pad_cols = target_shape[1] - src.shape[1]
+             return F.pad(src, (0, pad_cols, 0, pad_rows),
+                          value=0).to(torch.int32)
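add_lora_linear composes the two ops above: add_shrink projects the token activations down to each slice's LoRA rank, then add_expand projects each buffer back up and writes it into that slice of the fused output, exactly as the Semantics blocks in the docstrings describe. Here is a standalone sketch of that shrink/expand math with plain torch tensors (not part of this wheel; a single adapter and made-up sizes, without the stacked max_loras/indexing machinery):

import torch

num_tokens, in_features, rank = 4, 32, 8
output_slices = (16, 24)  # e.g. two fused projections of different widths
scale = 1.0

x = torch.randn(num_tokens, in_features)
y = torch.zeros(num_tokens, sum(output_slices))

# One (lora_a, lora_b) pair per output slice, mirroring lora_a_stacked /
# lora_b_stacked with max_loras = 1 and the singleton dim squeezed away.
lora_a = [torch.randn(rank, in_features) for _ in output_slices]
lora_b = [torch.randn(width, rank) for width in output_slices]

offset = 0
for a, b, width in zip(lora_a, lora_b, output_slices):
    buffer = (x @ a.T) * scale                    # shrink: (num_tokens, rank)
    y[:, offset:offset + width] += buffer @ b.T   # expand into this slice
    offset += width

The real wrapper additionally pads its index tensors to fixed shapes: _pad_prompt_mapping rounds the request count up to the next power of two (with a floor of MIN_NUM_SEQS = 8) so that the shapes XLA sees stay within a small, fixed set and, per the comment in _update_base_metadata, avoid recompilation on the TPU.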
tpu_inference/mock/__init__.py (file without changes)
tpu_inference/mock/vllm_config_utils.py
@@ -0,0 +1,28 @@
+ from dataclasses import dataclass, field
+ from typing import Any, List, Mapping
+
+
+ @dataclass
+ class ModelConfig():
+     max_model_len: int = 2048
+     max_prefill_len: int = 1024
+     prefill_batch_size: int = 1
+     decode_batch_size: int = 1
+     block_size: int = 16
+     num_layers: int = 32
+     num_kv_heads: int = 32
+     head_dim: int = 128
+     vocab_size: int = 32000
+     model: str = "llama3"
+     hf_config: str = ""
+     architectures: List[str] = field(default_factory=list)
+     override_generation_config: dict[str, Any] = field(default_factory=dict)
+     hf_overrides: dict[str, Any] = field(default_factory=dict)
+
+
+ @dataclass
+ class VllmConfig():
+     additional_config: Mapping[str, Any] = field(default_factory=dict)
+     # Set default max_model_len to turn off warnings.
+     model_config: ModelConfig = field(
+         default_factory=lambda: ModelConfig(max_model_len=1024))
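These mock dataclasses appear to stand in for vLLM's ModelConfig/VllmConfig in lightweight settings. A brief usage sketch (not part of this wheel; the additional_config key below is made up for illustration), using the import path from the file list above:

from tpu_inference.mock.vllm_config_utils import ModelConfig, VllmConfig

# The default VllmConfig carries a ModelConfig with max_model_len=1024,
# which the comment in the source sets to silence warnings.
cfg = VllmConfig(additional_config={"example_flag": True})
assert cfg.model_config.max_model_len == 1024

# Individual fields can be overridden as needed.
custom = VllmConfig(model_config=ModelConfig(max_model_len=4096,
                                             num_kv_heads=8))
assert custom.model_config.block_size == 16  # untouched default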