tpu-inference 0.11.1 (tpu_inference-0.11.1-py3-none-any.whl)

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this release of tpu-inference has been flagged as potentially problematic.

Files changed (168)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_adapters.py +83 -0
  4. tests/core/test_core_tpu.py +523 -0
  5. tests/core/test_disagg_executor.py +60 -0
  6. tests/core/test_disagg_utils.py +53 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  10. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  11. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  12. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  13. tests/lora/__init__.py +0 -0
  14. tests/lora/test_lora.py +123 -0
  15. tests/test_base.py +201 -0
  16. tests/test_quantization.py +836 -0
  17. tests/test_tpu_info.py +120 -0
  18. tests/test_utils.py +218 -0
  19. tests/tpu_backend_test.py +59 -0
  20. tpu_inference/__init__.py +30 -0
  21. tpu_inference/adapters/__init__.py +0 -0
  22. tpu_inference/adapters/vllm_adapters.py +42 -0
  23. tpu_inference/adapters/vllm_config_adapters.py +134 -0
  24. tpu_inference/backend.py +69 -0
  25. tpu_inference/core/__init__.py +0 -0
  26. tpu_inference/core/adapters.py +153 -0
  27. tpu_inference/core/core_tpu.py +776 -0
  28. tpu_inference/core/disagg_executor.py +117 -0
  29. tpu_inference/core/disagg_utils.py +51 -0
  30. tpu_inference/di/__init__.py +0 -0
  31. tpu_inference/di/abstracts.py +28 -0
  32. tpu_inference/di/host.py +76 -0
  33. tpu_inference/di/interfaces.py +51 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/tpu_connector.py +699 -0
  36. tpu_inference/distributed/utils.py +59 -0
  37. tpu_inference/executors/__init__.py +0 -0
  38. tpu_inference/executors/ray_distributed_executor.py +346 -0
  39. tpu_inference/experimental/__init__.py +0 -0
  40. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  41. tpu_inference/interfaces/__init__.py +0 -0
  42. tpu_inference/interfaces/cache.py +31 -0
  43. tpu_inference/interfaces/config.py +47 -0
  44. tpu_inference/interfaces/config_parts.py +117 -0
  45. tpu_inference/interfaces/engine.py +51 -0
  46. tpu_inference/interfaces/outputs.py +22 -0
  47. tpu_inference/interfaces/params.py +21 -0
  48. tpu_inference/interfaces/platform.py +74 -0
  49. tpu_inference/interfaces/request.py +39 -0
  50. tpu_inference/interfaces/scheduler.py +31 -0
  51. tpu_inference/kernels/__init__.py +0 -0
  52. tpu_inference/kernels/collectives/__init__.py +0 -0
  53. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  54. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  55. tpu_inference/kernels/collectives/util.py +47 -0
  56. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  57. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  58. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  59. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  60. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  61. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  62. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  66. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
  71. tpu_inference/layers/__init__.py +0 -0
  72. tpu_inference/layers/common/__init__.py +0 -0
  73. tpu_inference/layers/common/attention_metadata.py +34 -0
  74. tpu_inference/layers/jax/__init__.py +0 -0
  75. tpu_inference/layers/jax/attention/__init__.py +0 -0
  76. tpu_inference/layers/jax/attention/attention.py +254 -0
  77. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  78. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  79. tpu_inference/layers/jax/attention_interface.py +356 -0
  80. tpu_inference/layers/jax/base.py +151 -0
  81. tpu_inference/layers/jax/binary_search.py +295 -0
  82. tpu_inference/layers/jax/constants.py +88 -0
  83. tpu_inference/layers/jax/layers.py +301 -0
  84. tpu_inference/layers/jax/misc.py +16 -0
  85. tpu_inference/layers/jax/moe/__init__.py +0 -0
  86. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  87. tpu_inference/layers/jax/moe/moe.py +209 -0
  88. tpu_inference/layers/jax/rope.py +172 -0
  89. tpu_inference/layers/jax/rope_interface.py +214 -0
  90. tpu_inference/layers/jax/sample/__init__.py +0 -0
  91. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  92. tpu_inference/layers/jax/sample/sampling.py +95 -0
  93. tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
  94. tpu_inference/layers/jax/sharding.py +406 -0
  95. tpu_inference/layers/jax/transformer_block.py +76 -0
  96. tpu_inference/layers/vllm/__init__.py +0 -0
  97. tpu_inference/layers/vllm/attention.py +184 -0
  98. tpu_inference/layers/vllm/fused_moe.py +399 -0
  99. tpu_inference/layers/vllm/linear_common.py +186 -0
  100. tpu_inference/layers/vllm/quantization/__init__.py +34 -0
  101. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  102. tpu_inference/layers/vllm/quantization/common.py +105 -0
  103. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  104. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
  105. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  106. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  108. tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
  109. tpu_inference/layers/vllm/sharding.py +151 -0
  110. tpu_inference/logger.py +10 -0
  111. tpu_inference/lora/__init__.py +0 -0
  112. tpu_inference/lora/torch_lora_ops.py +103 -0
  113. tpu_inference/lora/torch_punica_tpu.py +308 -0
  114. tpu_inference/mock/__init__.py +0 -0
  115. tpu_inference/mock/vllm_config_utils.py +28 -0
  116. tpu_inference/mock/vllm_envs.py +1233 -0
  117. tpu_inference/mock/vllm_logger.py +212 -0
  118. tpu_inference/mock/vllm_logging_utils.py +15 -0
  119. tpu_inference/models/__init__.py +0 -0
  120. tpu_inference/models/common/__init__.py +0 -0
  121. tpu_inference/models/common/model_loader.py +433 -0
  122. tpu_inference/models/jax/__init__.py +0 -0
  123. tpu_inference/models/jax/deepseek_v3.py +868 -0
  124. tpu_inference/models/jax/llama3.py +366 -0
  125. tpu_inference/models/jax/llama4.py +473 -0
  126. tpu_inference/models/jax/llama_eagle3.py +333 -0
  127. tpu_inference/models/jax/phi3.py +376 -0
  128. tpu_inference/models/jax/qwen2.py +375 -0
  129. tpu_inference/models/jax/qwen2_5_vl.py +976 -0
  130. tpu_inference/models/jax/qwen3.py +302 -0
  131. tpu_inference/models/jax/utils/__init__.py +0 -0
  132. tpu_inference/models/jax/utils/file_utils.py +96 -0
  133. tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
  134. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
  136. tpu_inference/models/jax/utils/weight_utils.py +510 -0
  137. tpu_inference/models/vllm/__init__.py +0 -0
  138. tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
  139. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  140. tpu_inference/platforms/__init__.py +2 -0
  141. tpu_inference/platforms/tpu_jax.py +257 -0
  142. tpu_inference/runner/__init__.py +0 -0
  143. tpu_inference/runner/block_table_jax.py +122 -0
  144. tpu_inference/runner/compilation_manager.py +672 -0
  145. tpu_inference/runner/input_batch_jax.py +435 -0
  146. tpu_inference/runner/kv_cache.py +119 -0
  147. tpu_inference/runner/kv_cache_manager.py +460 -0
  148. tpu_inference/runner/lora_utils.py +92 -0
  149. tpu_inference/runner/multimodal_manager.py +208 -0
  150. tpu_inference/runner/persistent_batch_manager.py +244 -0
  151. tpu_inference/runner/speculative_decoding_manager.py +250 -0
  152. tpu_inference/runner/structured_decoding_manager.py +89 -0
  153. tpu_inference/runner/tpu_jax_runner.py +771 -0
  154. tpu_inference/runner/utils.py +426 -0
  155. tpu_inference/spec_decode/__init__.py +0 -0
  156. tpu_inference/spec_decode/jax/__init__.py +0 -0
  157. tpu_inference/spec_decode/jax/eagle3.py +334 -0
  158. tpu_inference/tpu_info.py +77 -0
  159. tpu_inference/utils.py +294 -0
  160. tpu_inference/worker/__init__.py +0 -0
  161. tpu_inference/worker/_temporary_vllm_compat.py +129 -0
  162. tpu_inference/worker/base.py +100 -0
  163. tpu_inference/worker/tpu_worker_jax.py +321 -0
  164. tpu_inference-0.11.1.dist-info/METADATA +101 -0
  165. tpu_inference-0.11.1.dist-info/RECORD +168 -0
  166. tpu_inference-0.11.1.dist-info/WHEEL +5 -0
  167. tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
  168. tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
tpu_inference/runner/kv_cache_manager.py
@@ -0,0 +1,460 @@
+ import functools
+ from typing import TYPE_CHECKING, Dict, List
+
+ import jax
+ import jax.numpy as jnp
+ from jax.sharding import NamedSharding, PartitionSpec
+ from torchax.ops.mappings import t2j_dtype
+ from vllm.attention import Attention
+ from vllm.attention.backends.abstract import AttentionType
+ from vllm.config import get_layers_from_vllm_config
+ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
+                                         KVCacheSpec, MLAAttentionSpec,
+                                         SlidingWindowSpec)
+
+ from tpu_inference import utils
+ from tpu_inference import utils as common_utils
+ from tpu_inference.logger import init_logger
+ from tpu_inference.runner import utils as runner_utils
+ from tpu_inference.runner.input_batch_jax import CachedRequestState, InputBatch
+ from tpu_inference.runner.kv_cache import create_kv_caches
+
+ if TYPE_CHECKING:
+     from vllm.v1.request import Request
+
+     from tpu_inference.runner.tpu_jax_runner import TPUModelRunner
+
+ logger = init_logger(__name__)
+
+
+ class KVCacheManager:
+
+     def __init__(self, runner: "TPUModelRunner"):
+         self.runner = runner
+         # Layer pairings for cross-layer KV sharing.
+         # If an Attention layer `layer_name` is in the keys of this dict, it
+         # means this layer will perform attention using the keys and values
+         # from the KV cache of `shared_kv_cache_layers[layer_name]`.
+         self.shared_kv_cache_layers: dict[str, str] = {}
+
+     def get_kv_cache_spec(self):
+         # TODO(xiang): this hack tricks engine core to init successfully
+         block_size = self.runner.cache_config.block_size
+         use_mla = self.runner.model_config.use_mla
+         kv_cache_spec: dict[str, KVCacheSpec] = {}
+
+         # If use pure jax (MODEL_IMPL_TYPE=flax_nnx), we don't register
+         # attention into compilation config.
+         # Use FullAttentionSpec for each layer
+         # TODO(pooyam): Is it possible to merge the logic for vllm and non-vllm models?
+         if len(self.runner.vllm_config.compilation_config.
+                static_forward_context) == 0:
+             model_config = self.runner.model_config
+             parallel_config = self.runner.parallel_config
+             # Pad num_kv_heads to multiple of TP size.
+             num_kv_heads = common_utils.get_padded_num_heads(
+                 model_config.get_total_num_kv_heads(),
+                 self.runner.mesh.shape["model"])
+             head_size = common_utils.get_padded_head_dim(
+                 model_config.get_head_size())
+             for i in range(model_config.get_num_layers(parallel_config)):
+                 if use_mla:
+                     kv_cache_spec[f"layer.{i}"] = MLAAttentionSpec(
+                         block_size=block_size,
+                         num_kv_heads=num_kv_heads,
+                         head_size=head_size,
+                         dtype=self.runner.kv_cache_dtype,
+                         cache_dtype_str=self.runner.vllm_config.cache_config.
+                         cache_dtype)
+                 else:
+                     kv_cache_spec[f"layer.{i}"] = FullAttentionSpec(
+                         block_size=block_size,
+                         num_kv_heads=num_kv_heads,
+                         head_size=head_size,
+                         dtype=self.runner.kv_cache_dtype)
+             if self.runner.speculative_config and self.runner.speculative_config.method == "eagle3":
+                 draft_model_config = self.runner.speculative_config.draft_model_config
+                 hf_config = draft_model_config.hf_config
+                 num_kv_heads = common_utils.get_padded_num_heads(
+                     hf_config.num_key_value_heads,
+                     self.runner.mesh.shape["model"])
+                 head_size = common_utils.get_padded_head_dim(
+                     hf_config.hidden_size // hf_config.num_attention_heads)
+
+                 # Eagle3 has only 1 layer
+                 for i in range(1):
+                     if use_mla:
+                         kv_cache_spec[f"layer.{i}"] = MLAAttentionSpec(
+                             block_size=block_size,
+                             num_kv_heads=num_kv_heads,
+                             head_size=head_size,
+                             dtype=self.runner.kv_cache_dtype,
+                             cache_dtype_str=self.runner.vllm_config.
+                             cache_config.cache_dtype)
+                     else:
+                         kv_cache_spec[f"draft_layer.{i}"] = FullAttentionSpec(
+                             block_size=block_size,
+                             num_kv_heads=num_kv_heads,
+                             head_size=head_size,
+                             dtype=self.runner.kv_cache_dtype)
+         else:
+             # Else propagate attention modules from compilation config.
+             layers = get_layers_from_vllm_config(self.runner.vllm_config,
+                                                  Attention)
+             for layer_name, attn_module in layers.items():
+                 if (kv_tgt_layer :=
+                         attn_module.kv_sharing_target_layer_name) is not None:
+                     # The layer doesn't need its own KV cache and will use that of
+                     # the target layer. We skip creating a KVCacheSpec for it, so
+                     # that KV cache management logic will act as this layer does
+                     # not exist, and doesn't allocate KV cache for the layer. This
+                     # enables the memory saving of cross-layer kv sharing, allowing
+                     # a given amount of memory to accommodate longer context lengths
+                     # or enable more requests to be processed simultaneously.
+                     self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
+                     continue
+                 if attn_module.attn_type == AttentionType.DECODER:
+                     if attn_module.sliding_window is not None:
+                         kv_cache_spec[layer_name] = SlidingWindowSpec(
+                             block_size=block_size,
+                             num_kv_heads=common_utils.get_padded_num_heads(
+                                 attn_module.num_kv_heads,
+                                 self.runner.mesh.shape["model"]),
+                             head_size=common_utils.get_padded_head_dim(
+                                 attn_module.head_size),
+                             dtype=self.runner.kv_cache_dtype,
+                             sliding_window=attn_module.sliding_window)
+                     elif use_mla:
+                         kv_cache_spec[f"layer.{i}"] = MLAAttentionSpec(
+                             block_size=block_size,
+                             num_kv_heads=attn_module.num_kv_heads,
+                             head_size=attn_module.head_size,
+                             dtype=self.runner.kv_cache_dtype,
+                             cache_dtype_str=self.runner.vllm_config.
+                             cache_config.cache_dtype)
+                     else:
+                         kv_cache_spec[layer_name] = FullAttentionSpec(
+                             block_size=block_size,
+                             num_kv_heads=common_utils.get_padded_num_heads(
+                                 attn_module.num_kv_heads,
+                                 self.runner.mesh.shape["model"]),
+                             head_size=common_utils.get_padded_head_dim(
+                                 attn_module.head_size),
+                             dtype=self.runner.kv_cache_dtype)
+                 elif attn_module.attn_type in (AttentionType.ENCODER,
+                                                AttentionType.ENCODER_ONLY):
+                     # encoder-only attention does not need KV cache.
+                     continue
+                 elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
+                     raise NotImplementedError
+                 else:
+                     raise ValueError(
+                         f"Unknown attention type: {attn_module.attn_type}")
+         return kv_cache_spec
+
+     def maybe_reinitialize_input_batch(self,
+                                        kv_cache_config: KVCacheConfig) -> None:
+         block_sizes = [
+             kv_cache_group.kv_cache_spec.block_size
+             for kv_cache_group in kv_cache_config.kv_cache_groups
+         ]
+         if block_sizes != [self.runner.cache_config.block_size]:
+             assert self.runner.cache_config.cpu_offload_gb == 0, (
+                 "Cannot re-initialize the input batch when CPU weight "
+                 "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 "  # noqa: E501
+                 "for more details.")
+             new_input_batch = InputBatch(
+                 max_num_reqs=self.runner.max_num_reqs,
+                 max_model_len=self.runner.max_model_len,
+                 max_num_batched_tokens=self.runner.max_num_tokens,
+                 pin_memory=False,
+                 vocab_size=self.runner.model_config.get_vocab_size(),
+                 block_sizes=block_sizes,
+             )
+             self.runner.input_batch = new_input_batch
+             self.runner.persistent_batch_manager.input_batch = new_input_batch
+
+     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
+         self.maybe_reinitialize_input_batch(kv_cache_config)
+
+         # uniform page size.
+         representative_spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec
+         page_size_bytes = representative_spec.page_size_bytes
+         self.runner.layer_name_to_kvcache_index: Dict[str, int] = {}
+         kv_caches = self.runner.kv_caches
+         num_blocks_list = []
+         for i, kv_cache_tensor in enumerate(kv_cache_config.kv_cache_tensors):
+             assert kv_cache_tensor.size % page_size_bytes == 0
+             num_blocks = kv_cache_tensor.size // page_size_bytes
+             # NOTE: we'll multiply the num_kv_heads by 2 in the function
+             kv_cache = create_kv_caches(
+                 num_blocks=num_blocks,
+                 block_size=representative_spec.block_size,
+                 num_kv_heads=representative_spec.num_kv_heads,
+                 head_size=representative_spec.head_size,
+                 mesh=self.runner.mesh,
+                 layer_names=[f'kv_cache_tensor.{i}'],
+                 cache_dtype=t2j_dtype(representative_spec.dtype),
+             )[0]
+             kv_caches.append(kv_cache)
+             num_blocks_list.append(num_blocks)
+             for layer_name in kv_cache_tensor.shared_by:
+                 self.runner.layer_name_to_kvcache_index[layer_name] = i
+
+         if self.shared_kv_cache_layers:
+             for layer_name, target_layer_name in self.shared_kv_cache_layers.items(
+             ):
+                 self.runner.layer_name_to_kvcache_index[
+                     layer_name] = self.runner.layer_name_to_kvcache_index[
+                         target_layer_name]
+
+         logger.info(
+             f"Init kv-cache | "
+             f"num_layers={len(kv_caches)} | "
+             f"shape=(num_blocks, {kv_caches[0].shape[1:]}) | "
+             f"num_blocks={num_blocks_list} | "
+             f"sharding={kv_caches[0].sharding} | "
+             f"dtype={kv_caches[0].dtype} | "
+             f"hbm={utils.hbm_usage_gb(self.runner.mesh.devices.flatten())}Gb")
+
+     @staticmethod
+     @functools.partial(jax.jit)
+     def _jitted_gather_kv_cache(kv_caches: List[jax.Array],
+                                 block_ids: jax.Array) -> List[jax.Array]:
+         """
+         JIT-compiled function to gather KV cache slices for all layers at once.
+         This uses jax.tree.map to apply the operation across all layers.
+         """
+
+         def gather_and_reshape(layer_kv_cache):
+             return layer_kv_cache.at[block_ids].get().reshape(
+                 -1, *layer_kv_cache.shape[2:])
+
+         return jax.tree.map(gather_and_reshape, kv_caches)
+
+     @staticmethod
+     @functools.partial(
+         jax.jit,
+         static_argnames=("len_block"),
+     )
+     def _jitted_gather_continuous_kv_cache(kv_caches: List[jax.Array],
+                                            start_block,
+                                            len_block) -> List[jax.Array]:
+         """
+         JIT-compiled function to gather KV cache slices for all layers at once.
+         This uses jax.tree.map to apply the operation across all layers.
+         """
+
+         def gather_and_reshape(layer_kv_cache):
+             shape = layer_kv_cache.shape
+             return jax.lax.dynamic_slice_in_dim(layer_kv_cache,
+                                                 start_block,
+                                                 len_block,
+                                                 axis=0).reshape(
+                                                     -1, *shape[2:])
+
+         return jax.tree.map(gather_and_reshape, kv_caches)
+
+     @staticmethod
+     @functools.partial(
+         jax.jit,
+         static_argnames=("block_size"),
+         donate_argnames=(
+             "kv_caches",
+             "kv_cache_slices",
+         ),
+     )
+     def _jitted_insert_kv_cache(
+         block_size,
+         kv_caches: List[jax.Array],
+         kv_cache_slices: List[jax.Array],
+         block_numbers: jax.Array,
+     ) -> List[jax.Array]:
+         """
+         JIT-compiled function to insert KV cache slices into the physical
+         cache for all layers at once. This fuses the pad, reshape, and scatter
+         operations into a single efficient kernel.
+         """
+
+         def _update_layer(cache, slices):
+             """The function to apply to each layer's cache and slices."""
+             reshaped_slices = slices.reshape(-1, 1, block_size,
+                                              *slices.shape[1:])
+             for (i, block_idx) in enumerate(block_numbers):
+                 cache = jax.lax.dynamic_update_slice_in_dim(cache,
+                                                             reshaped_slices[i],
+                                                             block_idx,
+                                                             axis=0)
+             return cache
+
+         return jax.tree.map(_update_layer, kv_caches, kv_cache_slices)
+
+     @staticmethod
+     @functools.partial(
+         jax.jit,
+         static_argnames=("block_size"),
+         donate_argnames=(
+             "kv_caches",
+             "kv_cache_slices",
+         ),
+     )
+     def _jitted_insert_continuous_kv_cache(
+         block_size,
+         kv_caches: List[jax.Array],
+         kv_cache_slices: List[jax.Array],
+         start_block,
+     ) -> List[jax.Array]:
+         """
+         JIT-compiled function to insert KV cache slices into continuous blocks.
+         Makes use of dynamic_update_slice_in_dim.
+         """
+
+         def _update_layer(cache, slices):
+             """The function to apply to each layer's cache and slices."""
+             reshaped_slices = slices.reshape(-1, block_size, *slices.shape[1:])
+
+             return jax.lax.dynamic_update_slice_in_dim(cache,
+                                                        reshaped_slices,
+                                                        start_block,
+                                                        axis=0)
+
+         return jax.tree.map(_update_layer, kv_caches, kv_cache_slices)
+
+     def get_kv_cache_for_block_ids(
+         self,
+         block_ids: List[int],
+     ) -> List[jax.Array]:
+         """
+         Extracts the KV cache slices for a given list of block IDs.
+         This assumes all provided blocks are full.
+
+         Args:
+             block_ids: A list of block IDs to extract KV cache for.
+
+         Returns:
+             A list of JAX arrays, with each array representing the KV cache
+             slices for a layer, concatenated for all blocks.
+         """
+         if block_ids == list(range(block_ids[0],
+                                    block_ids[0] + len(block_ids))):
+             with runner_utils.LatencyTracker(
+                     "BatchedGatherKVSlices-for-blocks"):
+                 batched_kv_cache_per_layer = self._jitted_gather_continuous_kv_cache(
+                     self.runner.kv_caches, block_ids[0], len(block_ids))
+
+         else:
+             with runner_utils.LatencyTracker(
+                     "BatchedGatherKVSlices-for-blocks"):
+                 batched_kv_cache_per_layer = self._jitted_gather_kv_cache(
+                     self.runner.kv_caches, jnp.array(block_ids))
+         return batched_kv_cache_per_layer
+
+     def transfer_kv_cache(self,
+                           kv_cache_slices: List[jax.Array]) -> List[jax.Array]:
+         """
+         Transfers KV cache slices to the runner's mesh.
+
+         This is used when a KV cache generated on one runner (e.g., a prefill
+         runner) needs to be used on another runner (e.g., a decode runner)
+         with a different device mesh. The transfer is asynchronous.
+
+         Args:
+             kv_cache_slices: A list of JAX arrays, where each array contains
+                 the KV cache slices for a specific layer. The shape of each
+                 slice is expected to be (num_tokens, num_kv_heads * 2, head_size).
+
+         Returns:
+             A new list of JAX arrays representing the KV cache slices, sharded
+             across the runner's device mesh.
+         """
+         # The KV cache slices have a shape of (num_tokens, num_kv_heads * 2, head_size).
+         # We shard along the num_kv_heads dimension (axis=1), which corresponds
+         # to the "model" axis of the mesh for tensor parallelism.
+         logger.debug(
+             f"Transferring kv cache shape {len(kv_cache_slices)} * {kv_cache_slices[0].shape} sharding {kv_cache_slices[0].sharding} size {kv_cache_slices[0].nbytes * len(kv_cache_slices)/1024/1024} Mbytes"
+         )
+         sharding = NamedSharding(self.runner.mesh,
+                                  PartitionSpec(None, "model"))
+         transferred_kv_cache = jax.device_put(kv_cache_slices, sharding)
+         for cache in transferred_kv_cache:
+             cache.block_until_ready()
+         return transferred_kv_cache
+
+     def insert_request_with_kv_cache(
+         self,
+         request: "Request",
+         kv_cache_slices: List[jax.Array],
+         block_ids: List[List[int]],
+     ):
+         """
+         Inserts a request and its KV cache into the runner. This is used to
+         transfer a request from a prefill runner to a decode runner.
+
+         The provided KV cache slices are copied into the physical blocks
+         allocated for the request. The runner's internal state is then updated
+         to include the request.
+
+         Args:
+             request: The vLLM request object, containing the state after prefill.
+             kv_cache_slices: The KV cache for the request, already transferred
+                 to this runner's mesh. This is a list of JAX arrays, one per layer.
+             block_ids: The physical block numbers allocated for this request on
+                 this runner. This is a list of lists, for each KV cache group.
+         """
+         # Assume one KV cache group for now, which is consistent with current setup.
+         if len(block_ids) > 1:
+             raise NotImplementedError(
+                 "Inserting KV cache for models with multiple KV cache groups "
+                 "is not supported yet.")
+         block_numbers = block_ids[0]
+         if block_numbers == list(
+                 range(block_numbers[0],
+                       block_numbers[0] + len(block_numbers))):
+             # For continuous blocks we use slice instead of scatter.
+             start_block = block_numbers[0]
+             with runner_utils.LatencyTracker(
+                     f"JittedInsertContinuousKVCache-b{len(block_numbers)}"):
+                 logger.debug(f"inserting to continuous blocks {block_numbers}")
+                 self.runner.kv_caches = KVCacheManager._jitted_insert_continuous_kv_cache(
+                     self.runner.block_size,
+                     self.runner.kv_caches,
+                     kv_cache_slices,
+                     start_block,
+                 )
+         else:
+             with runner_utils.LatencyTracker(
+                     f"JittedInsertKVCache-b{len(block_numbers)}"):
+                 logger.debug(
+                     f"inserting to non continuous blocks {block_numbers}")
+                 self.runner.kv_caches = KVCacheManager._jitted_insert_kv_cache(
+                     self.runner.block_size,
+                     self.runner.kv_caches,
+                     kv_cache_slices,
+                     jnp.array(block_numbers),
+                 )
+
+         logger.debug(
+             f"Updated kv cache entries cnt={len(self.runner.kv_caches)}")
+
+         # Update runner's internal state to track the new request.
+         req_id = request.request_id
+         if req_id in self.runner.requests:
+             logger.warning(
+                 f"Request {req_id} already exists in the runner. Overwriting.")
+
+         # Create a CachedRequestState object to add to the input batch.
+         req_state = CachedRequestState(
+             req_id=request.request_id,
+             prompt_token_ids=request.prompt_token_ids,
+             output_token_ids=[request.all_token_ids[-1]],
+             sampling_params=request.sampling_params,
+             block_ids=tuple(block_ids),
+             num_computed_tokens=request.num_computed_tokens,
+             lora_request=request.lora_request,
+             mm_features=getattr(request, "mm_features", []),
+             pooling_params=getattr(request, "pooling_params", None),
+             generator=None,
+         )
+
+         self.runner.requests[req_id] = req_state
+         self.runner.input_batch.add_request(req_state)
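
The two gather helpers above (and their insert counterparts) follow a common JAX pattern: when the requested block IDs form a contiguous run, one jax.lax.dynamic_slice_in_dim over the block axis replaces the scattered indexed gather. Below is a minimal, self-contained sketch of that choice; the shapes and names (NUM_BLOCKS, gather_contiguous, gather_scattered) are illustrative stand-ins, not the actual KV-cache layout or runner wiring.

import functools

import jax
import jax.numpy as jnp

NUM_BLOCKS, BLOCK_SIZE, NUM_HEADS, HEAD_DIM = 8, 4, 2, 3


@functools.partial(jax.jit, static_argnames=("len_block",))
def gather_contiguous(cache, start_block, len_block):
    # One dynamic slice over the block axis, then flatten blocks into tokens.
    sliced = jax.lax.dynamic_slice_in_dim(cache, start_block, len_block, axis=0)
    return sliced.reshape(-1, *cache.shape[2:])


@jax.jit
def gather_scattered(cache, block_ids):
    # Indexed gather for non-contiguous blocks, then flatten the same way.
    return cache.at[block_ids].get().reshape(-1, *cache.shape[2:])


cache = jnp.arange(NUM_BLOCKS * BLOCK_SIZE * NUM_HEADS * HEAD_DIM,
                   dtype=jnp.float32).reshape(NUM_BLOCKS, BLOCK_SIZE,
                                              NUM_HEADS, HEAD_DIM)
block_ids = [2, 3, 4]  # contiguous run -> slice path
a = gather_contiguous(cache, block_ids[0], len_block=len(block_ids))
b = gather_scattered(cache, jnp.array(block_ids))  # same result via gather
assert a.shape == (len(block_ids) * BLOCK_SIZE, NUM_HEADS, HEAD_DIM)
assert bool(jnp.all(a == b))

The insert helpers make the same choice in reverse with jax.lax.dynamic_update_slice_in_dim, and donate_argnames marks the cache buffers as donatable so XLA can reuse them for the output instead of allocating fresh copies on every update.
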
tpu_inference/runner/lora_utils.py
@@ -0,0 +1,92 @@
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING
+
+ import numpy as np
+ from torchax.interop import jax_view
+ from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA
+ from vllm.lora.request import LoRARequest
+
+ from tpu_inference.layers.vllm.sharding import shard_model_to_tpu
+
+ if TYPE_CHECKING:
+     from tpu_inference.runner.tpu_jax_runner import TPUModelRunner
+
+
+ class LoraUtils:
+
+     def __init__(self, runner: "TPUModelRunner"):
+         self.runner = runner
+
+     def set_active_loras(self, num_scheduled_tokens_per_req,
+                          total_num_scheduled_tokens,
+                          padded_total_num_scheduled_tokens):
+         # We need to respect padding when activating LoRA adapters
+         padded_num_scheduled_tokens_per_req = np.copy(
+             num_scheduled_tokens_per_req
+         )  # Copying to avoid accidental state corruption bugs
+         padded_num_scheduled_tokens_per_req[-1] += \
+             padded_total_num_scheduled_tokens - total_num_scheduled_tokens
+
+         prompt_lora_mapping: tuple[int, ...]  # of size input_batch.num_reqs
+         token_lora_mapping: tuple[int,
+                                   ...]  # of size np.sum(num_scheduled_tokens)
+         lora_requests: set[LoRARequest]
+         prompt_lora_mapping, token_lora_mapping, lora_requests = \
+             self.runner.input_batch.make_lora_inputs(padded_num_scheduled_tokens_per_req)
+         # One should not put lora_manager.set_active_loras under
+         # torchax.default_env() because set_active_loras also load lora from
+         # disk and torchax currently does not support that. Here we load the
+         # lora and set the lora weight to the linear layers.
+         self.runner._set_active_loras(prompt_lora_mapping, token_lora_mapping,
+                                       lora_requests)
+
+         params_and_buffers = shard_model_to_tpu(self.runner.model.model,
+                                                 self.runner.mesh)
+         self.runner.state = jax_view(params_and_buffers)
+
+     def extract_lora_metadata(self):
+         if self.runner.lora_config is None:
+             return None
+
+         metadata = {}
+         punica_wrapper = None
+         for _, m in self.runner.model.model.named_modules():
+             if isinstance(m, BaseLinearLayerWithLoRA):
+                 assert getattr(
+                     m, 'punica_wrapper', None
+                 ) is not None, 'A lora wrapper should have contained a punica_wrapper'
+                 punica_wrapper = m.punica_wrapper
+                 break
+         assert punica_wrapper is not None, 'Should have been able to find a punica wrapper from the Lora wrapper.'
+
+         # vars does not show inherited methods or class attributes but this is
+         # fine because we only care about instance attributes.
+         for k in vars(punica_wrapper):
+             v = getattr(punica_wrapper, k, None)
+             if k == 'device':  # Exclude string as it can't be traced by jax.jit
+                 continue
+             metadata[k] = v
+         return jax_view(metadata)
+
+
+ def replace_lora_metadata(model, metadata: dict, lora_config) -> dict:
+     if lora_config is None or not metadata:
+         return {}
+
+     original_metadata = {}
+     punica_wrapper = None
+     for _, m in model.named_modules():
+         if isinstance(m, BaseLinearLayerWithLoRA):
+             assert getattr(
+                 m, 'punica_wrapper', None
+             ) is not None, 'A lora wrapper should have contained a punica_wrapper'
+             punica_wrapper = m.punica_wrapper
+             break
+     assert punica_wrapper is not None, 'Should have been able to find a punica wrapper from the Lora wrapper.'
+
+     for k in vars(punica_wrapper):
+         if k in metadata:
+             original_metadata[k] = getattr(punica_wrapper, k)
+             setattr(punica_wrapper, k, metadata[k])
+     return original_metadata
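
extract_lora_metadata and replace_lora_metadata implement a save/replace/restore pattern over a wrapper's instance attributes: vars() enumerates the instance attributes, the non-traceable device string is skipped, and the dict returned by the replace call is what the caller later passes back in to restore the originals. A minimal sketch of that pattern with a stand-in class (FakePunicaWrapper and its attributes are hypothetical, with no vLLM or punica dependencies) might look like this:

# Toy sketch of the save/replace/restore attribute pattern used by
# replace_lora_metadata. FakePunicaWrapper is a stand-in, not the real class.
class FakePunicaWrapper:

    def __init__(self):
        self.token_lora_indices = [0, 0, 1]
        self.lora_ranks = [8, 16]
        self.device = "cpu"  # excluded from metadata, mirroring the real code


def extract_metadata(wrapper) -> dict:
    # vars() only lists instance attributes, which is all we care about here.
    return {k: v for k, v in vars(wrapper).items() if k != "device"}


def replace_metadata(wrapper, metadata: dict) -> dict:
    original = {}
    for k in vars(wrapper):
        if k in metadata:
            original[k] = getattr(wrapper, k)
            setattr(wrapper, k, metadata[k])
    return original


wrapper = FakePunicaWrapper()
snapshot = extract_metadata(wrapper)

# Temporarily swap in new values (e.g., values prepared for a traced call) ...
originals = replace_metadata(wrapper, {"token_lora_indices": [1, 1, 1]})
assert wrapper.token_lora_indices == [1, 1, 1]

# ... and restore the originals afterwards.
replace_metadata(wrapper, originals)
assert wrapper.token_lora_indices == snapshot["token_lora_indices"]

The real code additionally passes the extracted values through torchax's jax_view so the torch-side tensors are exposed to JAX before being handed to jitted functions.
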