tpu-inference 0.11.1 (tpu_inference-0.11.1-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference has been flagged as possibly problematic.

Files changed (168)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_adapters.py +83 -0
  4. tests/core/test_core_tpu.py +523 -0
  5. tests/core/test_disagg_executor.py +60 -0
  6. tests/core/test_disagg_utils.py +53 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  10. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  11. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  12. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  13. tests/lora/__init__.py +0 -0
  14. tests/lora/test_lora.py +123 -0
  15. tests/test_base.py +201 -0
  16. tests/test_quantization.py +836 -0
  17. tests/test_tpu_info.py +120 -0
  18. tests/test_utils.py +218 -0
  19. tests/tpu_backend_test.py +59 -0
  20. tpu_inference/__init__.py +30 -0
  21. tpu_inference/adapters/__init__.py +0 -0
  22. tpu_inference/adapters/vllm_adapters.py +42 -0
  23. tpu_inference/adapters/vllm_config_adapters.py +134 -0
  24. tpu_inference/backend.py +69 -0
  25. tpu_inference/core/__init__.py +0 -0
  26. tpu_inference/core/adapters.py +153 -0
  27. tpu_inference/core/core_tpu.py +776 -0
  28. tpu_inference/core/disagg_executor.py +117 -0
  29. tpu_inference/core/disagg_utils.py +51 -0
  30. tpu_inference/di/__init__.py +0 -0
  31. tpu_inference/di/abstracts.py +28 -0
  32. tpu_inference/di/host.py +76 -0
  33. tpu_inference/di/interfaces.py +51 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/tpu_connector.py +699 -0
  36. tpu_inference/distributed/utils.py +59 -0
  37. tpu_inference/executors/__init__.py +0 -0
  38. tpu_inference/executors/ray_distributed_executor.py +346 -0
  39. tpu_inference/experimental/__init__.py +0 -0
  40. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  41. tpu_inference/interfaces/__init__.py +0 -0
  42. tpu_inference/interfaces/cache.py +31 -0
  43. tpu_inference/interfaces/config.py +47 -0
  44. tpu_inference/interfaces/config_parts.py +117 -0
  45. tpu_inference/interfaces/engine.py +51 -0
  46. tpu_inference/interfaces/outputs.py +22 -0
  47. tpu_inference/interfaces/params.py +21 -0
  48. tpu_inference/interfaces/platform.py +74 -0
  49. tpu_inference/interfaces/request.py +39 -0
  50. tpu_inference/interfaces/scheduler.py +31 -0
  51. tpu_inference/kernels/__init__.py +0 -0
  52. tpu_inference/kernels/collectives/__init__.py +0 -0
  53. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  54. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  55. tpu_inference/kernels/collectives/util.py +47 -0
  56. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  57. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  58. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  59. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  60. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  61. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  62. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  66. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
  71. tpu_inference/layers/__init__.py +0 -0
  72. tpu_inference/layers/common/__init__.py +0 -0
  73. tpu_inference/layers/common/attention_metadata.py +34 -0
  74. tpu_inference/layers/jax/__init__.py +0 -0
  75. tpu_inference/layers/jax/attention/__init__.py +0 -0
  76. tpu_inference/layers/jax/attention/attention.py +254 -0
  77. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  78. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  79. tpu_inference/layers/jax/attention_interface.py +356 -0
  80. tpu_inference/layers/jax/base.py +151 -0
  81. tpu_inference/layers/jax/binary_search.py +295 -0
  82. tpu_inference/layers/jax/constants.py +88 -0
  83. tpu_inference/layers/jax/layers.py +301 -0
  84. tpu_inference/layers/jax/misc.py +16 -0
  85. tpu_inference/layers/jax/moe/__init__.py +0 -0
  86. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  87. tpu_inference/layers/jax/moe/moe.py +209 -0
  88. tpu_inference/layers/jax/rope.py +172 -0
  89. tpu_inference/layers/jax/rope_interface.py +214 -0
  90. tpu_inference/layers/jax/sample/__init__.py +0 -0
  91. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  92. tpu_inference/layers/jax/sample/sampling.py +95 -0
  93. tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
  94. tpu_inference/layers/jax/sharding.py +406 -0
  95. tpu_inference/layers/jax/transformer_block.py +76 -0
  96. tpu_inference/layers/vllm/__init__.py +0 -0
  97. tpu_inference/layers/vllm/attention.py +184 -0
  98. tpu_inference/layers/vllm/fused_moe.py +399 -0
  99. tpu_inference/layers/vllm/linear_common.py +186 -0
  100. tpu_inference/layers/vllm/quantization/__init__.py +34 -0
  101. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  102. tpu_inference/layers/vllm/quantization/common.py +105 -0
  103. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  104. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
  105. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  106. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  108. tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
  109. tpu_inference/layers/vllm/sharding.py +151 -0
  110. tpu_inference/logger.py +10 -0
  111. tpu_inference/lora/__init__.py +0 -0
  112. tpu_inference/lora/torch_lora_ops.py +103 -0
  113. tpu_inference/lora/torch_punica_tpu.py +308 -0
  114. tpu_inference/mock/__init__.py +0 -0
  115. tpu_inference/mock/vllm_config_utils.py +28 -0
  116. tpu_inference/mock/vllm_envs.py +1233 -0
  117. tpu_inference/mock/vllm_logger.py +212 -0
  118. tpu_inference/mock/vllm_logging_utils.py +15 -0
  119. tpu_inference/models/__init__.py +0 -0
  120. tpu_inference/models/common/__init__.py +0 -0
  121. tpu_inference/models/common/model_loader.py +433 -0
  122. tpu_inference/models/jax/__init__.py +0 -0
  123. tpu_inference/models/jax/deepseek_v3.py +868 -0
  124. tpu_inference/models/jax/llama3.py +366 -0
  125. tpu_inference/models/jax/llama4.py +473 -0
  126. tpu_inference/models/jax/llama_eagle3.py +333 -0
  127. tpu_inference/models/jax/phi3.py +376 -0
  128. tpu_inference/models/jax/qwen2.py +375 -0
  129. tpu_inference/models/jax/qwen2_5_vl.py +976 -0
  130. tpu_inference/models/jax/qwen3.py +302 -0
  131. tpu_inference/models/jax/utils/__init__.py +0 -0
  132. tpu_inference/models/jax/utils/file_utils.py +96 -0
  133. tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
  134. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
  136. tpu_inference/models/jax/utils/weight_utils.py +510 -0
  137. tpu_inference/models/vllm/__init__.py +0 -0
  138. tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
  139. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  140. tpu_inference/platforms/__init__.py +2 -0
  141. tpu_inference/platforms/tpu_jax.py +257 -0
  142. tpu_inference/runner/__init__.py +0 -0
  143. tpu_inference/runner/block_table_jax.py +122 -0
  144. tpu_inference/runner/compilation_manager.py +672 -0
  145. tpu_inference/runner/input_batch_jax.py +435 -0
  146. tpu_inference/runner/kv_cache.py +119 -0
  147. tpu_inference/runner/kv_cache_manager.py +460 -0
  148. tpu_inference/runner/lora_utils.py +92 -0
  149. tpu_inference/runner/multimodal_manager.py +208 -0
  150. tpu_inference/runner/persistent_batch_manager.py +244 -0
  151. tpu_inference/runner/speculative_decoding_manager.py +250 -0
  152. tpu_inference/runner/structured_decoding_manager.py +89 -0
  153. tpu_inference/runner/tpu_jax_runner.py +771 -0
  154. tpu_inference/runner/utils.py +426 -0
  155. tpu_inference/spec_decode/__init__.py +0 -0
  156. tpu_inference/spec_decode/jax/__init__.py +0 -0
  157. tpu_inference/spec_decode/jax/eagle3.py +334 -0
  158. tpu_inference/tpu_info.py +77 -0
  159. tpu_inference/utils.py +294 -0
  160. tpu_inference/worker/__init__.py +0 -0
  161. tpu_inference/worker/_temporary_vllm_compat.py +129 -0
  162. tpu_inference/worker/base.py +100 -0
  163. tpu_inference/worker/tpu_worker_jax.py +321 -0
  164. tpu_inference-0.11.1.dist-info/METADATA +101 -0
  165. tpu_inference-0.11.1.dist-info/RECORD +168 -0
  166. tpu_inference-0.11.1.dist-info/WHEEL +5 -0
  167. tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
  168. tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
tpu_inference/lora/torch_punica_tpu.py
@@ -0,0 +1,308 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import math
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+import torch.nn.functional as F
+import torchax
+from vllm.lora.punica_wrapper.utils import convert_mapping
+
+if TYPE_CHECKING:
+    # avoid circuit import
+    from vllm.lora.layers import LoRAMapping
+
+from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
+
+from tpu_inference.lora.torch_lora_ops import bgmv_expand_slice, bgmv_shrink
+
+
+class PunicaWrapperTPU(PunicaWrapperBase):
+    """
+    PunicaWrapperTPU is designed to manage and provide metadata for the punica
+    kernel. The main function is to maintain the state information for
+    Multi-LoRA, and to provide the interface for the pytorch punica ops.
+    """
+
+    def __init__(self, max_num_batched_tokens: int, max_batches: int,
+                 device: Union[torch.device, str], **kwargs):
+        PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
+                                   device)
+
+        # PunicaWrapperBase defines some tensors with dtype=torch.int64, which
+        # isn't supported by the TPU. So convert those tensors to int32.
+        # Not all of them are used by the TPU so only convert the useful ones.
+        self._token_lora_indices = self._token_lora_indices.to(
+            dtype=torch.int32)  # map from token to LoRA index.
+        self._sampler_indices = self._sampler_indices.to(dtype=torch.int32)
+        self._sampler_indices_padded = self._sampler_indices_padded.to(
+            dtype=torch.int32)
+
+    def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor:
+        return torch.narrow(self._token_lora_indices, 0, 0, x.size(0))
+
+    @property
+    def embeddings_indices(self) -> torch.Tensor:
+        """
+        This property provides access to the indices used for lora embeddings,
+        specifically for VocabParallelEmbeddingWithLoRA.
+        """
+        raise NotImplementedError(
+            "NYI: torch_punica_tpu.PunicaWrapperTPU.embeddings_indices.")
+
+    @property
+    def sampler_indices_padded(self) -> torch.Tensor:
+        """
+        This property provides access to padded sampler indices.
+        """
+        raise NotImplementedError(
+            "NYI: torch_punica_tpu.PunicaWrapperTPU.sampler_indices_padded.")
+
+    def add_shrink(self, y: Union[tuple[torch.Tensor, ...], torch.Tensor],
+                   x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...],
+                   scale: float, **kwargs) -> Optional[torch.Tensor]:
+        """
+        Performs GEMM for multiple slices of lora_a.
+
+        Semantics:
+            for i in range(len(lora_a_stacked)):
+                y[i] += (x @ lora_a_stacked[i]) * scale
+
+        Args:
+            y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors. (n_slices, num_tokens, r)
+            x (torch.Tensor): Input tensor. (num_tokens, in_features)
+            lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights. lora_a_stacked[i]: (max_loras, 1, max_lora_rank, in_features)
+            scale (float): Scaling factor for the operation
+        """
+        x = x.view(-1, x.shape[-1])
+
+        for slice_idx in range(len(lora_a_stacked)):
+            lora_s = lora_a_stacked[slice_idx]
+            y_s = bgmv_shrink(x, lora_s, self._get_token_lora_indices(x),
+                              scale)
+            y[slice_idx, :, :] = y_s  # type: ignore[index]
+        return y
+
+    def add_expand(self,
+                   y: torch.Tensor,
+                   x: Union[tuple[torch.Tensor, ...], torch.Tensor],
+                   lora_b_stacked: tuple[torch.Tensor, ...],
+                   output_slices: tuple[int, ...],
+                   offset_start: int = 0,
+                   add_inputs=True,
+                   **kwargs) -> torch.Tensor:
+        """
+        Performs GEMM for multiple slices of lora_b.
+
+        Semantics:
+            for i in range(len(lora_b_stacked)):
+                slice = output_slices[i]
+                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
+                offset += slice
+
+        Args:
+            y (torch.Tensor): Output tensor. (num_tokens, out_features)
+            x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors. (n_slices, num_tokens, r)
+            lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
+            output_slices (tuple[int, ...]): Every slice's size
+            add_inputs (bool): Defaults to True.
+        """
+        y_orig = y
+        y = y.view(-1, y.shape[-1])
+        offset_left = 0
+
+        for slice_idx in range(len(lora_b_stacked)):
+            y = bgmv_expand_slice(x[slice_idx], lora_b_stacked[slice_idx], y,
+                                  self._get_token_lora_indices(x[slice_idx]),
+                                  offset_left, output_slices[slice_idx],
+                                  add_inputs)
+            offset_left += output_slices[slice_idx]
+        return y.view(y_orig.shape)
+
+    def add_lora_embedding(self,
+                           y: torch.Tensor,
+                           x: torch.Tensor,
+                           lora_b_stacked: torch.Tensor,
+                           add_inputs: bool = True,
+                           **kwargs) -> torch.Tensor:
+        """
+        Applies lora specifically for VocabParallelEmbeddingWithLoRA.
+
+        Semantics:
+            y += x @ lora_b_stacked
+
+        Args:
+            y (torch.Tensor): Output tensor.
+            x (torch.Tensor): Input tensor.
+            lora_b_stacked (torch.Tensor): lora_b's weights.
+            add_inputs (bool): Default to True.
+        """
+        raise NotImplementedError(
+            "NYI: torch_punica_tpu.PunicaWrapperTPU.add_lora_embedding.")
+
+    def add_lora_linear(self,
+                        y: torch.Tensor,
+                        x: torch.Tensor,
+                        lora_a_stacked: tuple[torch.Tensor, ...],
+                        lora_b_stacked: tuple[torch.Tensor, ...],
+                        scale: float,
+                        output_slices: tuple[int, ...],
+                        *,
+                        buffer: Optional[tuple[torch.Tensor, ...]] = None,
+                        **kwargs) -> torch.Tensor:
+        """
+        Applicable to linear-related lora.
+
+        Semantics:
+            for i in range(len(lora_a_stacked)):
+                y[i] += (
+                    x[i].unsqueeze(0)
+                    @ lora_a_stacked[indices[i], layer_idx, :, :]
+                    @ lora_b_stacked[indices[i], layer_idx, :, :]
+                    * scale
+                ).squeeze(0)
+
+        Args:
+            y (torch.Tensor): Output tensor (bs, out_features). Will not be changed in-place.
+            x (torch.Tensor): Input tensor (bs, in_features)
+            lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight of length n_slices. lora_a_stacked[i]: (max_loras, 1, max_lora_rank, in_features)
+            lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight of length n_slices. lora_b_stacked[i]: (max_loras, 1, out_features, max_lora_rank)
+            output_slices (tuple[int, ...]): Every slice's size.
+            buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None.
+        """
+
+        assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
+
+        if buffer is None:
+            max_lora_rank = lora_b_stacked[0].size(-1)
+            num_tokens = x.size(0)
+            buffer = torch.zeros(
+                (len(output_slices), num_tokens, max_lora_rank),
+                dtype=x.dtype,
+                device=x.device,
+            )
+        buffer = self.add_shrink(
+            buffer, x, lora_a_stacked, scale,
+            **kwargs)  # (n_slices, num_tokens, max_lora_rank)
+        return self.add_expand(y,
+                               buffer,
+                               lora_b_stacked,
+                               output_slices,
+                               add_inputs=True,
+                               **kwargs)
+
+    def add_lora_logits(self,
+                        y: torch.Tensor,
+                        x: torch.Tensor,
+                        lora_a_stacked: torch.Tensor,
+                        lora_b_stacked: torch.Tensor,
+                        scale,
+                        *,
+                        buffer: Optional[torch.Tensor] = None,
+                        **kwargs) -> torch.Tensor:
+        """
+        Applies lora specifically for LogitsProcessorWithLoRA.
+
+        Semantics:
+            buffer = (x @ lora_a_stacked) * scale
+            y += buffer @ lora_b_stacked
+
+        Args:
+            y (torch.Tensor): Output tensor.
+            x (torch.Tensor): Input tensor.
+            lora_a_stacked (torch.Tensor): lora_a's weights.
+            lora_b_stacked (torch.Tensor): lora_b's weights.
+            scale (float): Scaling factor.
+            buffer (Optional[torch.Tensor]): Default to None.
+        """
+        raise NotImplementedError(
+            "NYI: torch_punica_tpu.PunicaWrapperTPU.add_lora_logits.")
+
+    @property
+    def token_lora_indices(self) -> torch.Tensor:
+        """
+        This property provides the lora indices corresponding to each token
+        in the batch. An index of -1 means no lora should be applied.
+        """
+        with torchax.default_env():
+            token_lora_len = self.indices_len[0]
+            return self._token_lora_indices[:token_lora_len]
+
+    # This performs the same tensor ops as the base method, except it does them
+    # on the CPU then transfers the results to the TPU
+    def _update_base_metadata(
+        self,
+        mapping: "LoRAMapping",
+        lora_index_to_id: list[Optional[int]],
+        max_loras: int,
+        vocab_size: int,
+        extra_vocab_size: int,
+    ):
+        # Pad the prompt mapping to avoid running into recompiles on the TPU
+        # TODO: Should this happen inside mapping internally? If so how can we
+        # avoid having backend specific LoRAMapping classes?
+        mapping.prompt_mapping = self._pad_prompt_mapping(
+            mapping.prompt_mapping)
+
+        (
+            base_indices,
+            sampler_indices,
+            sampler_indices_padded,
+            embeddings_indices,
+            indices_len,
+        ) = convert_mapping(
+            mapping,
+            lora_index_to_id,
+            max_loras,
+            vocab_size,
+            extra_vocab_size,
+            "cpu",
+        )
+        with torchax.default_env():
+            self._token_lora_indices = self._pad_to_shape(
+                base_indices, self._token_lora_indices.shape,
+                dims=1).to(self.device)
+            self._sampler_indices = self._pad_to_shape(
+                sampler_indices, self._sampler_indices.shape,
+                dims=1).to(self.device)
+            self._sampler_indices_padded = self._pad_to_shape(
+                sampler_indices_padded,
+                self._sampler_indices_padded.shape,
+                dims=1).to(self.device)
+            self._embeddings_indices = self._pad_to_shape(
+                embeddings_indices, self._embeddings_indices.shape,
+                dims=2).to(self.device)
+        self.indices_len[:] = indices_len
+
+    def _update_prefill_metadata(self,
+                                 token_lora_tensor: torch.Tensor) -> None:
+        with torchax.default_env():
+            self.batch_size = 1
+            self._lora_indices_per_batch[:self.
+                                         batch_size] = token_lora_tensor[:self.
+                                                                         batch_size]
+
+    def _pad_prompt_mapping(
+            self, prompt_mapping: tuple[int, ...]) -> tuple[int, ...]:
+        num_reqs = len(prompt_mapping)
+
+        # From vllm/v1/worker/tpu_model_runner:51, but need to avoid a circular
+        # import
+        MIN_NUM_SEQS = 8
+
+        padded_num_reqs = max(2**math.ceil(math.log2(num_reqs)), MIN_NUM_SEQS)
+        pad_len = padded_num_reqs - num_reqs
+
+        padding = [-1] * pad_len
+        return tuple(list(prompt_mapping) + padding)
+
+    def _pad_to_shape(self, src, target_shape, dims=1):
+        if dims == 1:
+            pad_len = target_shape[0] - src.shape[0]
+            return F.pad(src, (0, pad_len), value=0).to(torch.int32)
+        else:
+            pad_rows = target_shape[0] - src.shape[0]
+            pad_cols = target_shape[1] - src.shape[1]
+            return F.pad(src, (0, pad_cols, 0, pad_rows),
+                         value=0).to(torch.int32)
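For reference, the shrink/expand composition documented in the add_shrink, add_expand, and add_lora_linear docstrings above can be illustrated with plain PyTorch. The sketch below only mirrors the docstring semantics under the simplifying assumption of a single LoRA adapter applied to every token; it does not call the TPU kernels (bgmv_shrink / bgmv_expand_slice) or construct a PunicaWrapperTPU, and all tensor names and sizes are illustrative.

# Illustrative sketch of the shrink/expand semantics documented above.
import torch

num_tokens, in_features, rank = 4, 64, 8
output_slices = (32, 32)  # sizes of the fused output projections (illustrative)
scale = 1.0

x = torch.randn(num_tokens, in_features)
# One (rank, in_features) A matrix and one (out_i, rank) B matrix per slice.
lora_a = [torch.randn(rank, in_features) for _ in output_slices]
lora_b = [torch.randn(out, rank) for out in output_slices]
y = torch.zeros(num_tokens, sum(output_slices))

# Shrink: project every token into the low-rank space, one buffer per slice:
#   buffer[i] = (x @ lora_a[i].T) * scale   -> (n_slices, num_tokens, rank)
buffer = torch.stack([(x @ a.T) * scale for a in lora_a])

# Expand: project each buffer back and accumulate into its output slice.
offset = 0
for i, out in enumerate(output_slices):
    y[:, offset:offset + out] += buffer[i] @ lora_b[i].T
    offset += out

print(y.shape)  # torch.Size([4, 64])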
tpu_inference/mock/__init__.py (file without changes)
tpu_inference/mock/vllm_config_utils.py
@@ -0,0 +1,28 @@
+from dataclasses import dataclass, field
+from typing import Any, List, Mapping
+
+
+@dataclass
+class ModelConfig():
+    max_model_len: int = 2048
+    max_prefill_len: int = 1024
+    prefill_batch_size: int = 1
+    decode_batch_size: int = 1
+    block_size: int = 16
+    num_layers: int = 32
+    num_kv_heads: int = 32
+    head_dim: int = 128
+    vocab_size: int = 32000
+    model: str = "llama3"
+    hf_config: str = ""
+    architectures: List[str] = field(default_factory=list)
+    override_generation_config: dict[str, Any] = field(default_factory=dict)
+    hf_overrides: dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class VllmConfig():
+    additional_config: Mapping[str, Any] = field(default_factory=dict)
+    # Set default max_model_len to turn off warnings.
+    model_config: ModelConfig = field(
+        default_factory=lambda: ModelConfig(max_model_len=1024))
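The mock dataclasses above stand in for vLLM's configuration objects. A minimal usage sketch, assuming the module path from the file listing (tpu_inference.mock.vllm_config_utils); the override values and the additional_config key are purely illustrative:

# Hypothetical usage of the mock config dataclasses shown above.
from tpu_inference.mock.vllm_config_utils import ModelConfig, VllmConfig

# Default VllmConfig carries ModelConfig(max_model_len=1024), per the comment above.
cfg = VllmConfig()
assert cfg.model_config.max_model_len == 1024

# Override model settings and attach extra backend options.
custom = VllmConfig(
    model_config=ModelConfig(model="qwen2", max_model_len=4096, block_size=32),
    additional_config={"example_option": True},  # illustrative key, not defined by the package
)
print(custom.model_config.model, custom.model_config.max_model_len)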