tpu-inference 0.0.1rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of tpu-inference might be problematic; more details are available on the package's registry page.

Files changed (174)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +374 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +648 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +88 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +203 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +235 -0
  27. tpu_inference/__init__.py +53 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +49 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +727 -0
  37. tpu_inference/distributed/utils.py +60 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +160 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +382 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1566 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1501 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1603 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +396 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +469 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +110 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +331 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +368 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +310 -0
  120. tpu_inference/models/__init__.py +0 -0
  121. tpu_inference/models/common/__init__.py +0 -0
  122. tpu_inference/models/common/model_loader.py +478 -0
  123. tpu_inference/models/jax/__init__.py +0 -0
  124. tpu_inference/models/jax/deepseek_v3.py +868 -0
  125. tpu_inference/models/jax/gpt_oss.py +492 -0
  126. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  127. tpu_inference/models/jax/llama3.py +376 -0
  128. tpu_inference/models/jax/llama4.py +629 -0
  129. tpu_inference/models/jax/llama_eagle3.py +336 -0
  130. tpu_inference/models/jax/llama_guard_4.py +361 -0
  131. tpu_inference/models/jax/qwen2.py +376 -0
  132. tpu_inference/models/jax/qwen2_5_vl.py +1218 -0
  133. tpu_inference/models/jax/qwen3.py +303 -0
  134. tpu_inference/models/jax/utils/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/file_utils.py +96 -0
  136. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  137. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  138. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  139. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  140. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  141. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  142. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  143. tpu_inference/models/jax/utils/quantization/quantization_utils.py +650 -0
  144. tpu_inference/models/jax/utils/weight_utils.py +584 -0
  145. tpu_inference/models/vllm/__init__.py +0 -0
  146. tpu_inference/models/vllm/vllm_model_wrapper.py +293 -0
  147. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  148. tpu_inference/platforms/__init__.py +2 -0
  149. tpu_inference/platforms/tpu_platform.py +275 -0
  150. tpu_inference/runner/__init__.py +0 -0
  151. tpu_inference/runner/block_table.py +122 -0
  152. tpu_inference/runner/compilation_manager.py +865 -0
  153. tpu_inference/runner/input_batch.py +435 -0
  154. tpu_inference/runner/kv_cache.py +132 -0
  155. tpu_inference/runner/kv_cache_manager.py +478 -0
  156. tpu_inference/runner/lora_utils.py +92 -0
  157. tpu_inference/runner/multimodal_manager.py +217 -0
  158. tpu_inference/runner/persistent_batch_manager.py +282 -0
  159. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  160. tpu_inference/runner/structured_decoding_manager.py +87 -0
  161. tpu_inference/runner/tpu_runner.py +1744 -0
  162. tpu_inference/runner/utils.py +426 -0
  163. tpu_inference/spec_decode/__init__.py +0 -0
  164. tpu_inference/spec_decode/jax/__init__.py +0 -0
  165. tpu_inference/spec_decode/jax/eagle3.py +417 -0
  166. tpu_inference/tpu_info.py +78 -0
  167. tpu_inference/utils.py +340 -0
  168. tpu_inference/worker/__init__.py +0 -0
  169. tpu_inference/worker/tpu_worker.py +458 -0
  170. tpu_inference-0.0.1rc1.dist-info/METADATA +108 -0
  171. tpu_inference-0.0.1rc1.dist-info/RECORD +174 -0
  172. tpu_inference-0.0.1rc1.dist-info/WHEEL +5 -0
  173. tpu_inference-0.0.1rc1.dist-info/licenses/LICENSE +201 -0
  174. tpu_inference-0.0.1rc1.dist-info/top_level.txt +2 -0
tpu_inference/runner/kv_cache_manager.py
@@ -0,0 +1,478 @@
+ import functools
+ from typing import TYPE_CHECKING, Dict, List
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ import vllm.envs as envs
+ from jax.sharding import NamedSharding, PartitionSpec
+ from torchax.ops.mappings import t2j_dtype
+ from vllm.attention.backends.abstract import AttentionType
+ from vllm.attention.layer import Attention
+ from vllm.config import get_layers_from_vllm_config
+ from vllm.utils.math_utils import cdiv
+ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
+                                         KVCacheSpec, MLAAttentionSpec,
+                                         SlidingWindowSpec)
+
+ from tpu_inference import utils
+ from tpu_inference import utils as common_utils
+ from tpu_inference.logger import init_logger
+ from tpu_inference.runner import utils as runner_utils
+ from tpu_inference.runner.input_batch import CachedRequestState, InputBatch
+ from tpu_inference.runner.kv_cache import create_kv_caches
+
+ if TYPE_CHECKING:
+     from vllm.v1.request import Request
+
+     from tpu_inference.runner.tpu_runner import TPUModelRunner
+
+ logger = init_logger(__name__)
+
+
+ class KVCacheManager:
+
+     def __init__(self, runner: "TPUModelRunner"):
+         self.runner = runner
+         # Layer pairings for cross-layer KV sharing.
+         # If an Attention layer `layer_name` is in the keys of this dict, it
+         # means this layer will perform attention using the keys and values
+         # from the KV cache of `shared_kv_cache_layers[layer_name]`.
+         self.shared_kv_cache_layers: dict[str, str] = {}
+
+     def get_kv_cache_spec(self):
+         # TODO(xiang): this hack tricks engine core to init successfully
+         block_size = self.runner.cache_config.block_size
+         use_mla = self.runner.model_config.use_mla
+         kv_cache_spec: dict[str, KVCacheSpec] = {}
+
+         # If using pure JAX (MODEL_IMPL_TYPE=flax_nnx), we don't register
+         # attention into the compilation config, so use FullAttentionSpec for
+         # each layer.
+         # TODO(pooyam): Is it possible to merge the logic for vllm and non-vllm models?
+         if len(self.runner.vllm_config.compilation_config.
+                static_forward_context) == 0:
+             model_config = self.runner.model_config
+             parallel_config = self.runner.parallel_config
+             # Pad num_kv_heads to a multiple of the TP size.
+             num_kv_heads = common_utils.get_padded_num_heads(
+                 model_config.get_total_num_kv_heads(),
+                 self.runner.mesh.shape["model"])
+             head_size = common_utils.get_padded_head_dim(
+                 model_config.get_head_size())
+             for i in range(model_config.get_num_layers(parallel_config)):
+                 if use_mla:
+                     kv_cache_spec[f"layer.{i}"] = MLAAttentionSpec(
+                         block_size=block_size,
+                         num_kv_heads=num_kv_heads,
+                         head_size=head_size,
+                         dtype=self.runner.kv_cache_dtype,
+                         cache_dtype_str=self.runner.vllm_config.cache_config.
+                         cache_dtype)
+                 else:
+                     kv_cache_spec[f"layer.{i}"] = FullAttentionSpec(
+                         block_size=block_size,
+                         num_kv_heads=num_kv_heads,
+                         head_size=head_size,
+                         dtype=self.runner.kv_cache_dtype)
+             if self.runner.speculative_config and self.runner.speculative_config.method == "eagle3":
+                 draft_model_config = self.runner.speculative_config.draft_model_config
+                 hf_config = draft_model_config.hf_config
+                 num_kv_heads = common_utils.get_padded_num_heads(
+                     hf_config.num_key_value_heads,
+                     self.runner.mesh.shape["model"])
+                 head_size = common_utils.get_padded_head_dim(
+                     hf_config.hidden_size // hf_config.num_attention_heads)
+
+                 # Eagle3 has only 1 layer.
+                 for i in range(1):
+                     if use_mla:
+                         kv_cache_spec[f"layer.{i}"] = MLAAttentionSpec(
+                             block_size=block_size,
+                             num_kv_heads=num_kv_heads,
+                             head_size=head_size,
+                             dtype=self.runner.kv_cache_dtype,
+                             cache_dtype_str=self.runner.vllm_config.
+                             cache_config.cache_dtype)
+                     else:
+                         kv_cache_spec[f"draft_layer.{i}"] = FullAttentionSpec(
+                             block_size=block_size,
+                             num_kv_heads=num_kv_heads,
+                             head_size=head_size,
+                             dtype=self.runner.kv_cache_dtype)
+         else:
+             # Otherwise propagate attention modules from the compilation config.
+             layers = get_layers_from_vllm_config(self.runner.vllm_config,
+                                                  Attention)
+             for layer_name, attn_module in layers.items():
+                 if (kv_tgt_layer :=
+                         attn_module.kv_sharing_target_layer_name) is not None:
+                     # The layer doesn't need its own KV cache and will use that of
+                     # the target layer. We skip creating a KVCacheSpec for it, so
+                     # that KV cache management logic will act as if this layer does
+                     # not exist, and doesn't allocate KV cache for the layer. This
+                     # enables the memory saving of cross-layer KV sharing, allowing
+                     # a given amount of memory to accommodate longer context lengths
+                     # or more requests to be processed simultaneously.
+                     self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
+                     continue
+                 if attn_module.attn_type == AttentionType.DECODER:
+                     if attn_module.sliding_window is not None:
+                         kv_cache_spec[layer_name] = SlidingWindowSpec(
+                             block_size=block_size,
+                             num_kv_heads=common_utils.get_padded_num_heads(
+                                 attn_module.num_kv_heads,
+                                 self.runner.mesh.shape["model"]),
+                             head_size=common_utils.get_padded_head_dim(
+                                 attn_module.head_size),
+                             dtype=self.runner.kv_cache_dtype,
+                             sliding_window=attn_module.sliding_window)
+                     elif use_mla:
+                         kv_cache_spec[layer_name] = MLAAttentionSpec(
+                             block_size=block_size,
+                             num_kv_heads=attn_module.num_kv_heads,
+                             head_size=attn_module.head_size,
+                             dtype=self.runner.kv_cache_dtype,
+                             cache_dtype_str=self.runner.vllm_config.
+                             cache_config.cache_dtype)
+                     else:
+                         kv_cache_spec[layer_name] = FullAttentionSpec(
+                             block_size=block_size,
+                             num_kv_heads=common_utils.get_padded_num_heads(
+                                 attn_module.num_kv_heads,
+                                 self.runner.mesh.shape["model"]),
+                             head_size=common_utils.get_padded_head_dim(
+                                 attn_module.head_size),
+                             dtype=self.runner.kv_cache_dtype)
+                 elif attn_module.attn_type in (AttentionType.ENCODER,
+                                                AttentionType.ENCODER_ONLY):
+                     # Encoder-only attention does not need a KV cache.
+                     continue
+                 elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
+                     raise NotImplementedError
+                 else:
+                     raise ValueError(
+                         f"Unknown attention type: {attn_module.attn_type}")
+         return kv_cache_spec
+
+     def maybe_reinitialize_input_batch(self,
+                                        kv_cache_config: KVCacheConfig) -> None:
+         block_sizes = [
+             kv_cache_group.kv_cache_spec.block_size
+             for kv_cache_group in kv_cache_config.kv_cache_groups
+         ]
+         if block_sizes != [self.runner.cache_config.block_size]:
+             assert self.runner.cache_config.cpu_offload_gb == 0, (
+                 "Cannot re-initialize the input batch when CPU weight "
+                 "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 "  # noqa: E501
+                 "for more details.")
+             new_input_batch = InputBatch(
+                 max_num_reqs=self.runner.max_num_reqs,
+                 max_model_len=self.runner.max_model_len,
+                 max_num_batched_tokens=self.runner.max_num_tokens,
+                 pin_memory=False,
+                 vocab_size=self.runner.model_config.get_vocab_size(),
+                 block_sizes=block_sizes,
+             )
+             self.runner.input_batch = new_input_batch
+             self.runner.persistent_batch_manager.input_batch = new_input_batch
+             self.runner.block_tables_cpu = [
+                 np.zeros((self.runner.max_num_reqs,
+                           cdiv(self.runner.max_model_len, block_size)),
+                          dtype=np.int32) for block_size in block_sizes
+             ]
+
+     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
+         self.maybe_reinitialize_input_batch(kv_cache_config)
+
+         # uniform page size.
+         representative_spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec
+         page_size_bytes = representative_spec.page_size_bytes
+         self.runner.layer_name_to_kvcache_index: Dict[str, int] = {}
+         kv_caches = self.runner.kv_caches
+         num_blocks_list = []
+         for i, kv_cache_tensor in enumerate(kv_cache_config.kv_cache_tensors):
+             assert kv_cache_tensor.size % page_size_bytes == 0
+             num_blocks = kv_cache_tensor.size // page_size_bytes
+             dp_size = self.runner.vllm_config.sharding_config.total_dp_size
+             # num_blocks must be a multiple of dp_size
+             num_blocks = (num_blocks // dp_size) * dp_size
+             # NOTE: we'll multiply the num_kv_heads by 2 in the function
+             kv_cache = create_kv_caches(
+                 num_blocks=num_blocks,
+                 block_size=representative_spec.block_size,
+                 num_kv_heads=representative_spec.num_kv_heads,
+                 head_size=representative_spec.head_size,
+                 mesh=self.runner.mesh,
+                 layer_names=[f'kv_cache_tensor.{i}'],
+                 cache_dtype=t2j_dtype(representative_spec.dtype),
+             )[0]
+             kv_caches.append(kv_cache)
+             num_blocks_list.append(num_blocks)
+             for layer_name in kv_cache_tensor.shared_by:
+                 self.runner.layer_name_to_kvcache_index[layer_name] = i
+
+         if self.shared_kv_cache_layers:
+             for layer_name, target_layer_name in self.shared_kv_cache_layers.items(
+             ):
+                 self.runner.layer_name_to_kvcache_index[
+                     layer_name] = self.runner.layer_name_to_kvcache_index[
+                         target_layer_name]
+
+         logger.info(
+             f"Init kv-cache | "
+             f"num_layers={len(kv_caches)} | "
+             f"shape=(num_blocks, {kv_caches[0].shape[1:]}) | "
+             f"num_blocks={num_blocks_list} | "
+             f"sharding={kv_caches[0].sharding} | "
+             f"dtype={kv_caches[0].dtype} | "
+             f"hbm={utils.hbm_usage_gb(self.runner.mesh.devices.flatten())}Gb")
+
+     @staticmethod
+     @functools.partial(jax.jit)
+     def _jitted_gather_kv_cache(kv_caches: List[jax.Array],
+                                 block_ids: jax.Array) -> List[jax.Array]:
+         """
+         JIT-compiled function to gather KV cache slices for all layers at once.
+         This uses jax.tree.map to apply the operation across all layers.
+         """
+
+         def gather_and_reshape(layer_kv_cache):
+             return layer_kv_cache.at[block_ids].get().reshape(
+                 -1, *layer_kv_cache.shape[2:])
+
+         return jax.tree.map(gather_and_reshape, kv_caches)
+
+     @staticmethod
+     @functools.partial(
+         jax.jit,
+         static_argnames=("len_block", ),
+     )
+     def _jitted_gather_continuous_kv_cache(kv_caches: List[jax.Array],
+                                            start_block,
+                                            len_block) -> List[jax.Array]:
+         """
+         JIT-compiled function to gather KV cache slices for all layers at once.
+         This uses jax.tree.map to apply the operation across all layers.
+         """
+
+         def gather_and_reshape(layer_kv_cache):
+             shape = layer_kv_cache.shape
+             return jax.lax.dynamic_slice_in_dim(layer_kv_cache,
+                                                 start_block,
+                                                 len_block,
+                                                 axis=0).reshape(
+                                                     -1, *shape[2:])
+
+         return jax.tree.map(gather_and_reshape, kv_caches)
+
+     @staticmethod
+     @functools.partial(
+         jax.jit,
+         static_argnames=("block_size", ),
+         donate_argnames=(
+             "kv_caches",
+             "kv_cache_slices",
+         ),
+     )
+     def _jitted_insert_kv_cache(
+         block_size,
+         kv_caches: List[jax.Array],
+         kv_cache_slices: List[jax.Array],
+         block_numbers: jax.Array,
+     ) -> List[jax.Array]:
+         """
+         JIT-compiled function to insert KV cache slices into the physical
+         cache for all layers at once. This fuses the pad, reshape, and scatter
+         operations into a single efficient kernel.
+         """
+
+         def _update_layer(cache, slices):
+             """The function to apply to each layer's cache and slices."""
+             reshaped_slices = slices.reshape(-1, block_size, *slices.shape[1:])
+             # .at[].set() is functional: rebind the result so the update is kept.
+             cache = cache.at[block_numbers].set(reshaped_slices)
+             return cache
+
+         return jax.tree.map(_update_layer, kv_caches, kv_cache_slices)
+
+     @staticmethod
+     @functools.partial(
+         jax.jit,
+         static_argnames=("block_size", ),
+         donate_argnames=(
+             "kv_caches",
+             "kv_cache_slices",
+         ),
+     )
+     def _jitted_insert_continuous_kv_cache(
+         block_size,
+         kv_caches: List[jax.Array],
+         kv_cache_slices: List[jax.Array],
+         start_block,
+     ) -> List[jax.Array]:
+         """
+         JIT-compiled function to insert KV cache slices into continuous blocks.
+         Makes use of dynamic_update_slice_in_dim.
+         """
+
+         def _update_layer(cache, slices):
+             """The function to apply to each layer's cache and slices."""
+             reshaped_slices = slices.reshape(-1, block_size, *slices.shape[1:])
+
+             return jax.lax.dynamic_update_slice_in_dim(cache,
+                                                        reshaped_slices,
+                                                        start_block,
+                                                        axis=0)
+
+         return jax.tree.map(_update_layer, kv_caches, kv_cache_slices)
+
+     def get_kv_cache_for_block_ids(
+         self,
+         block_ids: List[int],
+     ) -> List[jax.Array]:
+         """
+         Extracts the KV cache slices for a given list of block IDs.
+         This assumes all provided blocks are full.
+
+         Args:
+             block_ids: A list of block IDs to extract KV cache for.
+
+         Returns:
+             A list of JAX arrays, with each array representing the KV cache
+             slices for a layer, concatenated for all blocks.
+         """
+         if block_ids == list(range(block_ids[0],
+                                    block_ids[0] + len(block_ids))):
+             batched_kv_cache_per_layer = self._jitted_gather_continuous_kv_cache(
+                 self.runner.kv_caches, block_ids[0], len(block_ids))
+         else:
+             batched_kv_cache_per_layer = self._jitted_gather_kv_cache(
+                 self.runner.kv_caches, jnp.array(block_ids))
+         return batched_kv_cache_per_layer
+
+     def transfer_kv_cache(self,
+                           kv_cache_slices: List[jax.Array]) -> List[jax.Array]:
+         """
+         Transfers KV cache slices to the runner's mesh.
+
+         This is used when a KV cache generated on one runner (e.g., a prefill
+         runner) needs to be used on another runner (e.g., a decode runner)
+         with a different device mesh. The transfer is asynchronous.
+
+         Args:
+             kv_cache_slices: A list of JAX arrays, where each array contains
+                 the KV cache slices for a specific layer. The shape of each
+                 slice is expected to be (num_tokens, num_kv_heads * 2, head_size).
+
+         Returns:
+             A new list of JAX arrays representing the KV cache slices, sharded
+             across the runner's device mesh.
+         """
+         # The KV cache slices have a shape of (num_tokens, num_kv_heads * 2, head_size).
+         # We shard along the num_kv_heads dimension (axis=1), which corresponds
+         # to the "model" axis of the mesh for tensor parallelism.
+         logger.debug(
+             f"Transferring kv cache shape {len(kv_cache_slices)} * {kv_cache_slices[0].shape} sharding {kv_cache_slices[0].sharding} size {kv_cache_slices[0].nbytes * len(kv_cache_slices)/1024/1024} Mbytes"
+         )
+         sharding = NamedSharding(self.runner.mesh,
+                                  PartitionSpec(None, "model"))
+         if envs.VLLM_TPU_USING_PATHWAYS:
+             from pathwaysutils.experimental import \
+                 reshard as experimental_reshard
+
+             def get_sharding(x):
+                 return sharding
+
+             sharding_spec_pytree = jax.tree.map(get_sharding, kv_cache_slices)
+             transferred_kv_cache = experimental_reshard.reshard(
+                 kv_cache_slices,
+                 sharding_spec_pytree,
+                 donate=False,
+             )
+         else:
+             transferred_kv_cache = jax.device_put(kv_cache_slices, sharding)
+
+         jax.block_until_ready(transferred_kv_cache)
+         return transferred_kv_cache
+
400
+ self,
401
+ request: "Request",
402
+ kv_cache_slices: List[jax.Array],
403
+ block_ids: List[List[int]],
404
+ ):
405
+ """
406
+ Inserts a request and its KV cache into the runner. This is used to
407
+ transfer a request from a prefill runner to a decode runner.
408
+
409
+ The provided KV cache slices are copied into the physical blocks
410
+ allocated for the request. The runner's internal state is then updated
411
+ to include the request.
412
+
413
+ Args:
414
+ request: The vLLM request object, containing the state after prefill.
415
+ kv_cache_slices: The KV cache for the request, already transferred
416
+ to this runner's mesh. This is a list of JAX arrays, one per layer.
417
+ block_ids: The physical block numbers allocated for this request on
418
+ this runner. This is a list of lists, for each KV cache group.
419
+ """
420
+ # Assume one KV cache group for now, which is consistent with current setup.
421
+ if len(block_ids) > 1:
422
+ raise NotImplementedError(
423
+ "Inserting KV cache for models with multiple KV cache groups "
424
+ "is not supported yet.")
425
+ block_numbers = block_ids[0]
426
+ if block_numbers == list(
427
+ range(block_numbers[0],
428
+ block_numbers[0] + len(block_numbers))):
429
+ # For continuous blocks we use slice instead of scatter.
430
+ start_block = block_numbers[0]
431
+ with runner_utils.LatencyTracker(
432
+ f"JittedInsertContinuousKVCache-b{len(block_numbers)}"):
433
+ logger.debug(f"inserting to continuous blocks {block_numbers}")
434
+ self.runner.kv_caches = KVCacheManager._jitted_insert_continuous_kv_cache(
435
+ self.runner.block_size,
436
+ self.runner.kv_caches,
437
+ kv_cache_slices,
438
+ start_block,
439
+ )
440
+ jax.block_until_ready(self.runner.kv_caches)
441
+ else:
442
+ with runner_utils.LatencyTracker(
443
+ f"JittedInsertKVCache-b{len(block_numbers)}"):
444
+ logger.debug(
445
+ f"inserting to non continuous blocks {block_numbers}")
446
+ self.runner.kv_caches = KVCacheManager._jitted_insert_kv_cache(
447
+ self.runner.block_size,
448
+ self.runner.kv_caches,
449
+ kv_cache_slices,
450
+ jnp.array(block_numbers),
451
+ )
452
+ jax.block_until_ready(self.runner.kv_caches)
453
+
454
+ logger.debug(
455
+ f"Updated kv cache entries cnt={len(self.runner.kv_caches)}")
456
+
457
+ # Update runner's internal state to track the new request.
458
+ req_id = request.request_id
459
+ if req_id in self.runner.requests:
460
+ logger.warning(
461
+ f"Request {req_id} already exists in the runner. Overwriting.")
462
+
463
+ # Create a CachedRequestState object to add to the input batch.
464
+ req_state = CachedRequestState(
465
+ req_id=request.request_id,
466
+ prompt_token_ids=request.prompt_token_ids,
467
+ output_token_ids=[request.all_token_ids[-1]],
468
+ sampling_params=request.sampling_params,
469
+ block_ids=tuple(block_ids),
470
+ num_computed_tokens=request.num_computed_tokens,
471
+ lora_request=request.lora_request,
472
+ mm_features=getattr(request, "mm_features", []),
473
+ pooling_params=getattr(request, "pooling_params", None),
474
+ generator=None,
475
+ )
476
+
477
+ self.runner.requests[req_id] = req_state
478
+ self.runner.input_batch.add_request(req_state)
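
The three public methods above (get_kv_cache_for_block_ids, transfer_kv_cache, insert_request_with_kv_cache) together form the prefill-to-decode hand-off path described in their docstrings. A minimal illustrative sketch of that flow follows; the KVCacheManager instances, request object, and block-id lists are assumed placeholders for this example, not names defined by the package.

    # Illustrative only: `prefill_kv_mgr`, `decode_kv_mgr`, `finished_request`,
    # `prefill_block_ids`, and `decode_block_ids` are assumed placeholders.
    def hand_off_request(prefill_kv_mgr, decode_kv_mgr, finished_request,
                         prefill_block_ids, decode_block_ids):
        # 1. Gather the per-layer KV slices for the blocks the prefill side filled.
        slices = prefill_kv_mgr.get_kv_cache_for_block_ids(prefill_block_ids)
        # 2. Reshard the slices onto the decode runner's mesh (device_put, or the
        #    Pathways reshard path when VLLM_TPU_USING_PATHWAYS is set).
        slices = decode_kv_mgr.transfer_kv_cache(slices)
        # 3. Scatter the slices into the decode runner's physical blocks and
        #    register the request with its input batch (single KV cache group).
        decode_kv_mgr.insert_request_with_kv_cache(finished_request, slices,
                                                   [decode_block_ids])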
tpu_inference/runner/lora_utils.py
@@ -0,0 +1,92 @@
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING
+
+ import numpy as np
+ from torchax.interop import jax_view
+ from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA
+ from vllm.lora.request import LoRARequest
+
+ from tpu_inference.layers.vllm.sharding import update_lora
+
+ if TYPE_CHECKING:
+     from tpu_inference.runner.tpu_runner import TPUModelRunner
+
+
+ class LoraUtils:
+
+     def __init__(self, runner: "TPUModelRunner"):
+         self.runner = runner
+
+     def set_active_loras(self, num_scheduled_tokens_per_req,
+                          total_num_scheduled_tokens,
+                          padded_total_num_scheduled_tokens):
+         # We need to respect padding when activating LoRA adapters.
+         padded_num_scheduled_tokens_per_req = np.copy(
+             num_scheduled_tokens_per_req
+         )  # Copying to avoid accidental state corruption bugs.
+         padded_num_scheduled_tokens_per_req[-1] += \
+             padded_total_num_scheduled_tokens - total_num_scheduled_tokens
+
+         prompt_lora_mapping: tuple[int, ...]  # of size input_batch.num_reqs
+         token_lora_mapping: tuple[int,
+                                   ...]  # of size np.sum(num_scheduled_tokens)
+         lora_requests: set[LoRARequest]
+         prompt_lora_mapping, token_lora_mapping, lora_requests = \
+             self.runner.input_batch.make_lora_inputs(padded_num_scheduled_tokens_per_req)
+         # One should not put lora_manager.set_active_loras under
+         # torchax.default_env() because set_active_loras also loads LoRA
+         # weights from disk and torchax currently does not support that. Here
+         # we load the LoRA and set the LoRA weights on the linear layers.
+         self.runner._set_active_loras(prompt_lora_mapping, token_lora_mapping,
+                                       lora_requests)
+
+         params_and_buffers = update_lora(
+             self.runner.model.model, initial_params_buffers=self.runner.state)
+         self.runner.state = jax_view(params_and_buffers)
+
+     def extract_lora_metadata(self):
+         if self.runner.lora_config is None:
+             return None
+
+         metadata = {}
+         punica_wrapper = None
+         for _, m in self.runner.model.model.named_modules():
+             if isinstance(m, BaseLinearLayerWithLoRA):
+                 assert getattr(
+                     m, 'punica_wrapper', None
+                 ) is not None, 'A lora wrapper should have contained a punica_wrapper'
+                 punica_wrapper = m.punica_wrapper
+                 break
+         assert punica_wrapper is not None, 'Should have been able to find a punica wrapper from the Lora wrapper.'
+
+         # vars() does not show inherited methods or class attributes, but this
+         # is fine because we only care about instance attributes.
+         for k in vars(punica_wrapper):
+             v = getattr(punica_wrapper, k, None)
+             if k == 'device':  # Exclude it: strings can't be traced by jax.jit.
+                 continue
+             metadata[k] = v
+         return jax_view(metadata)
+
+
+ def replace_lora_metadata(model, metadata: dict, lora_config) -> dict:
+     if lora_config is None or not metadata:
+         return {}
+
+     original_metadata = {}
+     punica_wrapper = None
+     for _, m in model.named_modules():
+         if isinstance(m, BaseLinearLayerWithLoRA):
+             assert getattr(
+                 m, 'punica_wrapper', None
+             ) is not None, 'A lora wrapper should have contained a punica_wrapper'
+             punica_wrapper = m.punica_wrapper
+             break
+     assert punica_wrapper is not None, 'Should have been able to find a punica wrapper from the Lora wrapper.'
+
+     for k in vars(punica_wrapper):
+         if k in metadata:
+             original_metadata[k] = getattr(punica_wrapper, k)
+             setattr(punica_wrapper, k, metadata[k])
+     return original_metadata
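
For context, extract_lora_metadata and replace_lora_metadata appear designed to be used as a pair: the punica wrapper's instance attributes are captured once as a jax_view, swapped onto the wrapper around a traced model call, and then restored from the returned originals. A rough sketch of that pattern, with `runner`, `model`, and `model_inputs` as assumed placeholders rather than package APIs:

    # Illustrative only; `runner`, `model`, and `model_inputs` are placeholders.
    lora_utils = LoraUtils(runner)
    metadata = lora_utils.extract_lora_metadata()  # jax_view of wrapper attrs

    originals = replace_lora_metadata(model, metadata, runner.lora_config)
    try:
        outputs = model(**model_inputs)  # runs with the swapped-in metadata
    finally:
        # Restore the original attribute values on the punica wrapper.
        replace_lora_metadata(model, originals, runner.lora_config)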