tpu-inference 0.11.1rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic.

Files changed (123)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_adapters.py +83 -0
  4. tests/core/test_core_tpu.py +523 -0
  5. tests/core/test_disagg_executor.py +60 -0
  6. tests/core/test_disagg_utils.py +53 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  10. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  11. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  12. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  13. tests/lora/__init__.py +0 -0
  14. tests/lora/test_lora.py +123 -0
  15. tests/test_base.py +201 -0
  16. tests/test_quantization.py +836 -0
  17. tests/test_tpu_info.py +120 -0
  18. tests/test_utils.py +218 -0
  19. tests/tpu_backend_test.py +59 -0
  20. tpu_inference/__init__.py +30 -0
  21. tpu_inference/adapters/__init__.py +0 -0
  22. tpu_inference/adapters/vllm_adapters.py +42 -0
  23. tpu_inference/adapters/vllm_config_adapters.py +134 -0
  24. tpu_inference/backend.py +69 -0
  25. tpu_inference/core/__init__.py +0 -0
  26. tpu_inference/core/adapters.py +153 -0
  27. tpu_inference/core/core_tpu.py +776 -0
  28. tpu_inference/core/disagg_executor.py +117 -0
  29. tpu_inference/core/disagg_utils.py +51 -0
  30. tpu_inference/di/__init__.py +0 -0
  31. tpu_inference/di/abstracts.py +28 -0
  32. tpu_inference/di/host.py +76 -0
  33. tpu_inference/di/interfaces.py +51 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/tpu_connector.py +699 -0
  36. tpu_inference/distributed/utils.py +59 -0
  37. tpu_inference/executors/__init__.py +0 -0
  38. tpu_inference/executors/ray_distributed_executor.py +346 -0
  39. tpu_inference/experimental/__init__.py +0 -0
  40. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  41. tpu_inference/interfaces/__init__.py +0 -0
  42. tpu_inference/interfaces/cache.py +31 -0
  43. tpu_inference/interfaces/config.py +47 -0
  44. tpu_inference/interfaces/config_parts.py +117 -0
  45. tpu_inference/interfaces/engine.py +51 -0
  46. tpu_inference/interfaces/outputs.py +22 -0
  47. tpu_inference/interfaces/params.py +21 -0
  48. tpu_inference/interfaces/platform.py +74 -0
  49. tpu_inference/interfaces/request.py +39 -0
  50. tpu_inference/interfaces/scheduler.py +31 -0
  51. tpu_inference/kernels/__init__.py +0 -0
  52. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  53. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  54. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  55. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  56. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  57. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  58. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  59. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  60. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  61. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  63. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  64. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
  65. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
  67. tpu_inference/logger.py +10 -0
  68. tpu_inference/lora/__init__.py +0 -0
  69. tpu_inference/lora/torch_lora_ops.py +103 -0
  70. tpu_inference/lora/torch_punica_tpu.py +308 -0
  71. tpu_inference/mock/__init__.py +0 -0
  72. tpu_inference/mock/vllm_config_utils.py +28 -0
  73. tpu_inference/mock/vllm_envs.py +1233 -0
  74. tpu_inference/mock/vllm_logger.py +212 -0
  75. tpu_inference/mock/vllm_logging_utils.py +15 -0
  76. tpu_inference/models/__init__.py +0 -0
  77. tpu_inference/models/jax/__init__.py +0 -0
  78. tpu_inference/models/jax/deepseek_v3.py +868 -0
  79. tpu_inference/models/jax/llama3.py +366 -0
  80. tpu_inference/models/jax/llama4.py +473 -0
  81. tpu_inference/models/jax/llama_eagle3.py +333 -0
  82. tpu_inference/models/jax/phi3.py +376 -0
  83. tpu_inference/models/jax/qwen2.py +375 -0
  84. tpu_inference/models/jax/qwen2_5_vl.py +976 -0
  85. tpu_inference/models/jax/qwen3.py +302 -0
  86. tpu_inference/models/jax/utils/__init__.py +0 -0
  87. tpu_inference/models/jax/utils/file_utils.py +96 -0
  88. tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
  89. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  90. tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
  91. tpu_inference/models/jax/utils/weight_utils.py +510 -0
  92. tpu_inference/models/vllm/__init__.py +0 -0
  93. tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
  94. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  95. tpu_inference/platforms/__init__.py +2 -0
  96. tpu_inference/platforms/tpu_jax.py +257 -0
  97. tpu_inference/runner/__init__.py +0 -0
  98. tpu_inference/runner/block_table_jax.py +122 -0
  99. tpu_inference/runner/compilation_manager.py +672 -0
  100. tpu_inference/runner/input_batch_jax.py +435 -0
  101. tpu_inference/runner/kv_cache.py +119 -0
  102. tpu_inference/runner/kv_cache_manager.py +460 -0
  103. tpu_inference/runner/lora_utils.py +92 -0
  104. tpu_inference/runner/multimodal_manager.py +208 -0
  105. tpu_inference/runner/persistent_batch_manager.py +244 -0
  106. tpu_inference/runner/speculative_decoding_manager.py +250 -0
  107. tpu_inference/runner/structured_decoding_manager.py +89 -0
  108. tpu_inference/runner/tpu_jax_runner.py +771 -0
  109. tpu_inference/runner/utils.py +426 -0
  110. tpu_inference/spec_decode/__init__.py +0 -0
  111. tpu_inference/spec_decode/jax/__init__.py +0 -0
  112. tpu_inference/spec_decode/jax/eagle3.py +334 -0
  113. tpu_inference/tpu_info.py +77 -0
  114. tpu_inference/utils.py +294 -0
  115. tpu_inference/worker/__init__.py +0 -0
  116. tpu_inference/worker/_temporary_vllm_compat.py +129 -0
  117. tpu_inference/worker/base.py +100 -0
  118. tpu_inference/worker/tpu_worker_jax.py +321 -0
  119. tpu_inference-0.11.1rc1.dist-info/METADATA +101 -0
  120. tpu_inference-0.11.1rc1.dist-info/RECORD +123 -0
  121. tpu_inference-0.11.1rc1.dist-info/WHEEL +5 -0
  122. tpu_inference-0.11.1rc1.dist-info/licenses/LICENSE +201 -0
  123. tpu_inference-0.11.1rc1.dist-info/top_level.txt +2 -0
tpu_inference/kernels/ragged_paged_attention/v3/util.py
@@ -0,0 +1,47 @@
+ """Utility functions for ragged paged attention."""
+ import jax
+ from jax._src import dtypes
+
+
+ def cdiv(a, b):
+     assert b != 0
+     return (a + b - 1) // b
+
+
+ def align_to(x, a):
+     return cdiv(x, a) * a
+
+
+ def get_dtype_packing(dtype):
+     bits = dtypes.bit_width(dtype)
+     return 32 // bits
+
+
+ def next_power_of_2(x: int):
+     """Finds the smallest power of 2 >= x using bit manipulation.
+
+     Args:
+       x: The input number (should be an integer).
+
+     Returns:
+       The smallest integer power of 2 that is >= x.
+     """
+     assert x > 0
+     if x == 1:
+         return 1
+     return 1 << (x - 1).bit_length()
+
+
+ def get_tpu_version() -> int:
+     """Returns the numeric version of the TPU, or -1 if not on TPU."""
+     kind = jax.devices()[0].device_kind
+     if 'TPU' not in kind:
+         return -1
+     if kind.endswith(' lite'):
+         kind = kind[:-len(' lite')]
+     if kind.endswith('p'):
+         kind = kind[:-1]
+     if kind == 'TPU7x':
+         return 7
+     assert kind[:-1] == 'TPU v', kind
+     return int(kind[-1])
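
For reference, a minimal usage sketch of the helpers above (an illustrative example, not part of the package; the import path assumes this hunk is tpu_inference/kernels/ragged_paged_attention/v3/util.py, whose +47 line count in the file list matches). The expected values follow directly from the function bodies shown:

    import jax.numpy as jnp
    from tpu_inference.kernels.ragged_paged_attention.v3.util import (
        align_to, cdiv, get_dtype_packing, next_power_of_2)

    # Ceiling division and rounding up to a multiple of the alignment.
    assert cdiv(7, 2) == 4
    assert align_to(100, 16) == 112

    # Number of values of a dtype that pack into a 32-bit word:
    # bfloat16 is 16 bits wide, so two values per word; int8 gives four.
    assert get_dtype_packing(jnp.bfloat16) == 2
    assert get_dtype_packing(jnp.int8) == 4

    # Smallest power of two >= x.
    assert next_power_of_2(1) == 1
    assert next_power_of_2(5) == 8
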
tpu_inference/logger.py
@@ -0,0 +1,10 @@
+ # SPDX-License-Identifier: Apache-2.0
+
+ from vllm.logger import _VllmLogger
+ from vllm.logger import init_logger as init_vllm_logger
+
+
+ def init_logger(name: str) -> _VllmLogger:
+     # Prepend the root "vllm" to the module path to use vllm's configured logger.
+     patched_name = "vllm." + name
+     return init_vllm_logger(patched_name)
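
As a quick illustration (not part of the diff; the module path tpu_inference/logger.py is taken from the file list), callers get a logger nested under vLLM's root so that vLLM's logging configuration applies:

    from tpu_inference.logger import init_logger

    # In a module named "tpu_inference.worker.tpu_worker_jax", this returns the
    # logger "vllm.tpu_inference.worker.tpu_worker_jax", which inherits vLLM's
    # handlers and formatting.
    logger = init_logger(__name__)
    logger.info("TPU worker initialized")
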
tpu_inference/lora/__init__.py: File without changes
tpu_inference/lora/torch_lora_ops.py
@@ -0,0 +1,103 @@
+ # SPDX-License-Identifier: Apache-2.0
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+ import jax
+ import jax.numpy as jnp
+ import torch
+ import torch.nn.functional as F
+ from torchax.interop import call_jax
+
+
+ @jax.jit
+ def bgmv_jax(
+         inputs,  # [num_tokens, hidden_size]
+         loras,  # [num_loras, lora_rank, hidden_size]
+         idxs,  # [num_tokens]
+ ):
+     return jnp.einsum(
+         "td,tX,Xld->tl",
+         inputs,
+         jax.nn.one_hot(idxs, loras.shape[0], dtype=inputs.dtype),
+         loras,
+     )
+
+
+ def bgmv_torch(
+         inputs,  # [num_tokens, hidden_size]
+         loras,  # [num_loras, 1, lora_rank, hidden_size]
+         idxs,  # [num_tokens]
+ ):  # [num_tokens, lora_rank]
+     # TODO(xiowei): use the below one_hot impl (added in https://github.com/pytorch/xla/pull/9523) after we upgrade torchax version.
+     # if len(loras.shape) == 4:
+     #     loras = loras.squeeze(axis=1)
+     # return torch.einsum(
+     #     "td,tX,Xld->tl",
+     #     inputs,
+     #     torch.nn.functional.one_hot(idxs.long(), loras.shape[0]),
+     #     loras,
+     # )  # [num_tokens, lora_rank]
+
+     if len(loras.shape) == 4:
+         loras = loras.squeeze(axis=1)
+     return call_jax(bgmv_jax, inputs, loras, idxs)
+
+
+ def bgmv_shrink(
+     inputs: torch.Tensor,
+     lora_b_weights: torch.Tensor,
+     lora_indices_tensor: torch.Tensor,
+     scaling: float = 1.0,
+ ):
+     """
+     Args:
+         inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
+         lora_b_weights (torch.Tensor): LoRA weights of shape
+             [max_loras, 1, max_lora_rank, hidden_size].
+         output_tensor (torch.Tensor): (Unused) output tensor (placeholder).
+         lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
+             indicating which LoRA matrix to use for each token.
+         scaling (float, optional): Scalar multiplier applied to the output.
+     """
+     return scaling * bgmv_torch(inputs, lora_b_weights, lora_indices_tensor)
+
+
+ def bgmv_expand_slice(
+     inputs: torch.Tensor,
+     lora_b_weights: torch.Tensor,
+     output_tensor: torch.Tensor,
+     lora_indices_tensor: torch.Tensor,
+     slice_offset: int,
+     slice_size: int,
+     add_inputs: bool = True,
+ ):
+     """
+     Args:
+         inputs (torch.Tensor): Input tensor of shape [num_tokens, lora_rank].
+
+         lora_b_weights (torch.Tensor): LoRA weights of shape
+             [num_loras, 1, out_features, lora_rank].
+
+         output_tensor (torch.Tensor): output tensor of shape
+             [num_tokens, out_features * num_slices].
+
+         lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
+             indicating which LoRA matrix to use for each token.
+         add_inputs (bool): Whether or not to add the input tensor to the output
+             tensor.
+     """
+     outputs = bgmv_torch(inputs, lora_b_weights, lora_indices_tensor)
+
+     outputs = F.pad(
+         outputs,
+         (
+             slice_offset,
+             output_tensor.shape[1] - (slice_offset + slice_size),
+             0,
+             0,
+         ),
+     )
+
+     if add_inputs:
+         return output_tensor + outputs
+     else:
+         return outputs
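
The einsum "td,tX,Xld->tl" in bgmv_jax is a batched gather-plus-matmul: the one-hot factor selects each token's LoRA matrix, which is then applied to that token's hidden state. A small reference sketch of the same semantics with an explicit loop (an illustrative example with assumed shapes, not part of the package):

    import jax
    import jax.numpy as jnp

    def bgmv_reference(inputs, loras, idxs):
        # inputs: [num_tokens, hidden_size]
        # loras:  [num_loras, lora_rank, hidden_size]
        # idxs:   [num_tokens], LoRA index per token
        return jnp.stack(
            [loras[idxs[t]] @ inputs[t] for t in range(inputs.shape[0])])

    key = jax.random.PRNGKey(0)
    inputs = jax.random.normal(key, (4, 16))    # 4 tokens, hidden_size 16
    loras = jax.random.normal(key, (2, 8, 16))  # 2 LoRAs, rank 8
    idxs = jnp.array([0, 1, 1, 0], dtype=jnp.int32)

    # Agrees with bgmv_jax(inputs, loras, idxs) up to floating-point tolerance.
    reference = bgmv_reference(inputs, loras, idxs)
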
tpu_inference/lora/torch_punica_tpu.py
@@ -0,0 +1,308 @@
+ # SPDX-License-Identifier: Apache-2.0
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+ import math
+ from typing import TYPE_CHECKING, Optional, Union
+
+ import torch
+ import torch.nn.functional as F
+ import torchax
+ from vllm.lora.punica_wrapper.utils import convert_mapping
+
+ if TYPE_CHECKING:
+     # avoid circular import
+     from vllm.lora.layers import LoRAMapping
+
+ from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
+
+ from tpu_inference.lora.torch_lora_ops import bgmv_expand_slice, bgmv_shrink
+
+
+ class PunicaWrapperTPU(PunicaWrapperBase):
+     """
+     PunicaWrapperTPU is designed to manage and provide metadata for the punica
+     kernel. The main function is to maintain the state information for
+     Multi-LoRA, and to provide the interface for the pytorch punica ops.
+     """
+
+     def __init__(self, max_num_batched_tokens: int, max_batches: int,
+                  device: Union[torch.device, str], **kwargs):
+         PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
+                                    device)
+
+         # PunicaWrapperBase defines some tensors with dtype=torch.int64, which
+         # isn't supported by the TPU. So convert those tensors to int32.
+         # Not all of them are used by the TPU so only convert the useful ones.
+         self._token_lora_indices = self._token_lora_indices.to(
+             dtype=torch.int32)  # map from token to LoRA index.
+         self._sampler_indices = self._sampler_indices.to(dtype=torch.int32)
+         self._sampler_indices_padded = self._sampler_indices_padded.to(
+             dtype=torch.int32)
+
+     def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor:
+         return torch.narrow(self._token_lora_indices, 0, 0, x.size(0))
+
+     @property
+     def embeddings_indices(self) -> torch.Tensor:
+         """
+         This property provides access to the indices used for lora embeddings,
+         specifically for VocabParallelEmbeddingWithLoRA.
+         """
+         raise NotImplementedError(
+             "NYI: torch_punica_tpu.PunicaWrapperTPU.embeddings_indices.")
+
+     @property
+     def sampler_indices_padded(self) -> torch.Tensor:
+         """
+         This property provides access to padded sampler indices.
+         """
+         raise NotImplementedError(
+             "NYI: torch_punica_tpu.PunicaWrapperTPU.sampler_indices_padded.")
+
+     def add_shrink(self, y: Union[tuple[torch.Tensor, ...], torch.Tensor],
+                    x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...],
+                    scale: float, **kwargs) -> Optional[torch.Tensor]:
+         """
+         Performs GEMM for multiple slices of lora_a.
+
+         Semantics:
+         for i in range(len(lora_a_stacked)):
+             y[i] += (x @ lora_a_stacked[i]) * scale
+
+         Args:
+             y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors. (n_slices, num_tokens, r)
+             x (torch.Tensor): Input tensor. (num_tokens, in_features)
+             lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights. lora_a_stacked[i]: (max_loras, 1, max_lora_rank, in_features)
+             scale (float): Scaling factor for the operation
+         """
+         x = x.view(-1, x.shape[-1])
+
+         for slice_idx in range(len(lora_a_stacked)):
+             lora_s = lora_a_stacked[slice_idx]
+             y_s = bgmv_shrink(x, lora_s, self._get_token_lora_indices(x),
+                               scale)
+             y[slice_idx, :, :] = y_s  # type: ignore[index]
+         return y
+
+     def add_expand(self,
+                    y: torch.Tensor,
+                    x: Union[tuple[torch.Tensor, ...], torch.Tensor],
+                    lora_b_stacked: tuple[torch.Tensor, ...],
+                    output_slices: tuple[int, ...],
+                    offset_start: int = 0,
+                    add_inputs=True,
+                    **kwargs) -> torch.Tensor:
+         """
+         Performs GEMM for multiple slices of lora_b.
+
+         Semantics:
+         for i in range(len(lora_b_stacked)):
+             slice = output_slices[i]
+             y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
+             offset += slice
+
+         Args:
+             y (torch.Tensor): Output tensor. (num_tokens, out_features)
+             x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors. (n_slices, num_tokens, r)
+             lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
+             output_slices (tuple[int, ...]): Every slice's size
+             add_inputs (bool): Defaults to True.
+         """
+         y_orig = y
+         y = y.view(-1, y.shape[-1])
+         offset_left = 0
+
+         for slice_idx in range(len(lora_b_stacked)):
+             y = bgmv_expand_slice(x[slice_idx], lora_b_stacked[slice_idx], y,
+                                   self._get_token_lora_indices(x[slice_idx]),
+                                   offset_left, output_slices[slice_idx],
+                                   add_inputs)
+             offset_left += output_slices[slice_idx]
+         return y.view(y_orig.shape)
+
+     def add_lora_embedding(self,
+                            y: torch.Tensor,
+                            x: torch.Tensor,
+                            lora_b_stacked: torch.Tensor,
+                            add_inputs: bool = True,
+                            **kwargs) -> torch.Tensor:
+         """
+         Applies lora specifically for VocabParallelEmbeddingWithLoRA.
+
+         Semantics:
+             y += x @ lora_b_stacked
+
+         Args:
+             y (torch.Tensor): Output tensor.
+             x (torch.Tensor): Input tensor.
+             lora_b_stacked (torch.Tensor): lora_b's weights.
+             add_inputs (bool): Default to True.
+         """
+         raise NotImplementedError(
+             "NYI: torch_punica_tpu.PunicaWrapperTPU.add_lora_embedding.")
+
+     def add_lora_linear(self,
+                         y: torch.Tensor,
+                         x: torch.Tensor,
+                         lora_a_stacked: tuple[torch.Tensor, ...],
+                         lora_b_stacked: tuple[torch.Tensor, ...],
+                         scale: float,
+                         output_slices: tuple[int, ...],
+                         *,
+                         buffer: Optional[tuple[torch.Tensor, ...]] = None,
+                         **kwargs) -> torch.Tensor:
+         """
+         Applicable to linear-related lora.
+
+         Semantics:
+             for i in range(len(lora_a_stacked)):
+                 y[i] += (
+                     x[i].unsqueeze(0)
+                     @ lora_a_stacked[indices[i], layer_idx, :, :]
+                     @ lora_b_stacked[indices[i], layer_idx, :, :]
+                     * scale
+                 ).squeeze(0)
+
+         Args:
+             y (torch.Tensor): Output tensor (bs, out_features). Will not be changed in-place.
+             x (torch.Tensor): Input tensor (bs, in_features)
+             lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight of length n_slices. lora_a_stacked[i]: (max_loras, 1, max_lora_rank, in_features)
+             lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight of length n_slices. lora_b_stacked[i]: (max_loras, 1, out_features, max_lora_rank)
+             output_slices (tuple[int, ...]): Every slice's size.
+             buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None.
+         """
+
+         assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
+
+         if buffer is None:
+             max_lora_rank = lora_b_stacked[0].size(-1)
+             num_tokens = x.size(0)
+             buffer = torch.zeros(
+                 (len(output_slices), num_tokens, max_lora_rank),
+                 dtype=x.dtype,
+                 device=x.device,
+             )
+         buffer = self.add_shrink(
+             buffer, x, lora_a_stacked, scale,
+             **kwargs)  # (n_slices, num_tokens, max_lora_rank)
+         return self.add_expand(y,
+                                buffer,
+                                lora_b_stacked,
+                                output_slices,
+                                add_inputs=True,
+                                **kwargs)
+
+     def add_lora_logits(self,
+                         y: torch.Tensor,
+                         x: torch.Tensor,
+                         lora_a_stacked: torch.Tensor,
+                         lora_b_stacked: torch.Tensor,
+                         scale,
+                         *,
+                         buffer: Optional[torch.Tensor] = None,
+                         **kwargs) -> torch.Tensor:
+         """
+         Applies lora specifically for LogitsProcessorWithLoRA.
+
+         Semantics:
+             buffer = (x @ lora_a_stacked) * scale
+             y += buffer @ lora_b_stacked
+
+         Args:
+             y (torch.Tensor): Output tensor.
+             x (torch.Tensor): Input tensor.
+             lora_a_stacked (torch.Tensor): lora_a's weights.
+             lora_b_stacked (torch.Tensor): lora_b's weights.
+             scale (float): Scaling factor.
+             buffer (Optional[torch.Tensor]): Default to None.
+         """
+         raise NotImplementedError(
+             "NYI: torch_punica_tpu.PunicaWrapperTPU.add_lora_logits.")
+
+     @property
+     def token_lora_indices(self) -> torch.Tensor:
+         """
+         This property provides the lora indices corresponding to each token
+         in the batch. An index of -1 means no lora should be applied.
+         """
+         with torchax.default_env():
+             token_lora_len = self.indices_len[0]
+             return self._token_lora_indices[:token_lora_len]
+
+     # This performs the same tensor ops as the base method, except it does them
+     # on the CPU then transfers the results to the TPU
+     def _update_base_metadata(
+         self,
+         mapping: "LoRAMapping",
+         lora_index_to_id: list[Optional[int]],
+         max_loras: int,
+         vocab_size: int,
+         extra_vocab_size: int,
+     ):
+         # Pad the prompt mapping to avoid running into recompiles on the TPU
+         # TODO: Should this happen inside mapping internally? If so how can we
+         # avoid having backend specific LoRAMapping classes?
+         mapping.prompt_mapping = self._pad_prompt_mapping(
+             mapping.prompt_mapping)
+
+         (
+             base_indices,
+             sampler_indices,
+             sampler_indices_padded,
+             embeddings_indices,
+             indices_len,
+         ) = convert_mapping(
+             mapping,
+             lora_index_to_id,
+             max_loras,
+             vocab_size,
+             extra_vocab_size,
+             "cpu",
+         )
+         with torchax.default_env():
+             self._token_lora_indices = self._pad_to_shape(
+                 base_indices, self._token_lora_indices.shape,
+                 dims=1).to(self.device)
+             self._sampler_indices = self._pad_to_shape(
+                 sampler_indices, self._sampler_indices.shape,
+                 dims=1).to(self.device)
+             self._sampler_indices_padded = self._pad_to_shape(
+                 sampler_indices_padded,
+                 self._sampler_indices_padded.shape,
+                 dims=1).to(self.device)
+             self._embeddings_indices = self._pad_to_shape(
+                 embeddings_indices, self._embeddings_indices.shape,
+                 dims=2).to(self.device)
+             self.indices_len[:] = indices_len
+
+     def _update_prefill_metadata(self,
+                                  token_lora_tensor: torch.Tensor) -> None:
+         with torchax.default_env():
+             self.batch_size = 1
+             self._lora_indices_per_batch[:self.
+                                          batch_size] = token_lora_tensor[:self.
+                                                                          batch_size]
+
+     def _pad_prompt_mapping(
+             self, prompt_mapping: tuple[int, ...]) -> tuple[int, ...]:
+         num_reqs = len(prompt_mapping)
+
+         # From vllm/v1/worker/tpu_model_runner:51, but need to avoid a circular
+         # import
+         MIN_NUM_SEQS = 8
+
+         padded_num_reqs = max(2**math.ceil(math.log2(num_reqs)), MIN_NUM_SEQS)
+         pad_len = padded_num_reqs - num_reqs
+
+         padding = [-1] * pad_len
+         return tuple(list(prompt_mapping) + padding)
+
+     def _pad_to_shape(self, src, target_shape, dims=1):
+         if dims == 1:
+             pad_len = target_shape[0] - src.shape[0]
+             return F.pad(src, (0, pad_len), value=0).to(torch.int32)
+         else:
+             pad_rows = target_shape[0] - src.shape[0]
+             pad_cols = target_shape[1] - src.shape[1]
+             return F.pad(src, (0, pad_cols, 0, pad_rows),
+                          value=0).to(torch.int32)
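
add_lora_linear composes the two methods above: add_shrink projects each token through its lora_a into the LoRA rank, and add_expand projects back through lora_b and writes each slice into its column range of the output. A plain-PyTorch sketch of the per-token semantics (an illustrative reference with assumed shapes, not part of the package; the wrapper itself routes these computations through the bgmv kernels):

    import torch

    def lora_linear_reference(y, x, lora_a_stacked, lora_b_stacked,
                              token_lora_idx, output_slices, scale):
        # y: [num_tokens, out_features], x: [num_tokens, in_features]
        # lora_a_stacked[i]: [max_loras, 1, rank, in_features]
        # lora_b_stacked[i]: [max_loras, 1, slice_size, rank]
        y = y.clone()
        offset = 0
        for i, slice_size in enumerate(output_slices):
            for t in range(x.shape[0]):
                idx = int(token_lora_idx[t])
                if idx < 0:  # -1 means no LoRA; one_hot(-1) is all zeros.
                    continue
                a = lora_a_stacked[i][idx, 0]   # [rank, in_features]
                b = lora_b_stacked[i][idx, 0]   # [slice_size, rank]
                shrunk = (a @ x[t]) * scale     # [rank]
                y[t, offset:offset + slice_size] += b @ shrunk
            offset += slice_size
        return y
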
tpu_inference/mock/__init__.py: File without changes
tpu_inference/mock/vllm_config_utils.py
@@ -0,0 +1,28 @@
+ from dataclasses import dataclass, field
+ from typing import Any, List, Mapping
+
+
+ @dataclass
+ class ModelConfig():
+     max_model_len: int = 2048
+     max_prefill_len: int = 1024
+     prefill_batch_size: int = 1
+     decode_batch_size: int = 1
+     block_size: int = 16
+     num_layers: int = 32
+     num_kv_heads: int = 32
+     head_dim: int = 128
+     vocab_size: int = 32000
+     model: str = "llama3"
+     hf_config: str = ""
+     architectures: List[str] = field(default_factory=list)
+     override_generation_config: dict[str, Any] = field(default_factory=dict)
+     hf_overrides: dict[str, Any] = field(default_factory=dict)
+
+
+ @dataclass
+ class VllmConfig():
+     additional_config: Mapping[str, Any] = field(default_factory=dict)
+     # Set default max_model_len to turn off warnings.
+     model_config: ModelConfig = field(
+         default_factory=lambda: ModelConfig(max_model_len=1024))
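
A short usage sketch of these mock config objects (an illustrative example, not part of the package; the import path follows the file list entry tpu_inference/mock/vllm_config_utils.py):

    from tpu_inference.mock.vllm_config_utils import ModelConfig, VllmConfig

    # The default VllmConfig carries a ModelConfig with max_model_len=1024.
    cfg = VllmConfig()
    assert cfg.model_config.max_model_len == 1024

    # Individual fields can be overridden per test, e.g. a longer context
    # window and fewer KV heads.
    cfg = VllmConfig(
        model_config=ModelConfig(max_model_len=8192, num_kv_heads=8))
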