tpu_inference-0.11.1.dev202511150811-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic.

Files changed (179)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +105 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +654 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +96 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +182 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +236 -0
  27. tpu_inference/__init__.py +34 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +51 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +728 -0
  37. tpu_inference/distributed/utils.py +59 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +107 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +362 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +390 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +507 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +105 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +311 -0
  120. tpu_inference/mock/__init__.py +0 -0
  121. tpu_inference/mock/vllm_config_utils.py +28 -0
  122. tpu_inference/mock/vllm_envs.py +1219 -0
  123. tpu_inference/mock/vllm_logger.py +212 -0
  124. tpu_inference/mock/vllm_logging_utils.py +15 -0
  125. tpu_inference/models/__init__.py +0 -0
  126. tpu_inference/models/common/__init__.py +0 -0
  127. tpu_inference/models/common/model_loader.py +444 -0
  128. tpu_inference/models/jax/__init__.py +0 -0
  129. tpu_inference/models/jax/deepseek_v3.py +868 -0
  130. tpu_inference/models/jax/gpt_oss.py +492 -0
  131. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  132. tpu_inference/models/jax/llama3.py +375 -0
  133. tpu_inference/models/jax/llama4.py +629 -0
  134. tpu_inference/models/jax/llama_eagle3.py +333 -0
  135. tpu_inference/models/jax/phi3.py +376 -0
  136. tpu_inference/models/jax/qwen2.py +375 -0
  137. tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
  138. tpu_inference/models/jax/qwen3.py +302 -0
  139. tpu_inference/models/jax/utils/__init__.py +0 -0
  140. tpu_inference/models/jax/utils/file_utils.py +96 -0
  141. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  142. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  143. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  144. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  145. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  146. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  147. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  148. tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
  149. tpu_inference/models/jax/utils/weight_utils.py +529 -0
  150. tpu_inference/models/vllm/__init__.py +0 -0
  151. tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
  152. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  153. tpu_inference/platforms/__init__.py +2 -0
  154. tpu_inference/platforms/tpu_platform.py +269 -0
  155. tpu_inference/runner/__init__.py +0 -0
  156. tpu_inference/runner/block_table.py +122 -0
  157. tpu_inference/runner/compilation_manager.py +780 -0
  158. tpu_inference/runner/input_batch.py +435 -0
  159. tpu_inference/runner/kv_cache.py +132 -0
  160. tpu_inference/runner/kv_cache_manager.py +479 -0
  161. tpu_inference/runner/lora_utils.py +92 -0
  162. tpu_inference/runner/multimodal_manager.py +217 -0
  163. tpu_inference/runner/persistent_batch_manager.py +244 -0
  164. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  165. tpu_inference/runner/structured_decoding_manager.py +88 -0
  166. tpu_inference/runner/tpu_runner.py +1620 -0
  167. tpu_inference/runner/utils.py +426 -0
  168. tpu_inference/spec_decode/__init__.py +0 -0
  169. tpu_inference/spec_decode/jax/__init__.py +0 -0
  170. tpu_inference/spec_decode/jax/eagle3.py +367 -0
  171. tpu_inference/tpu_info.py +77 -0
  172. tpu_inference/utils.py +317 -0
  173. tpu_inference/worker/__init__.py +0 -0
  174. tpu_inference/worker/tpu_worker.py +321 -0
  175. tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
  176. tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
  177. tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
  178. tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
  179. tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tpu_inference/runner/kv_cache_manager.py
@@ -0,0 +1,479 @@
+import functools
+import math
+from typing import TYPE_CHECKING, Dict, List
+
+import jax
+import jax.numpy as jnp
+import vllm.envs as envs
+from jax.sharding import NamedSharding, PartitionSpec
+from torchax.ops.mappings import t2j_dtype
+from vllm.attention import Attention
+from vllm.attention.backends.abstract import AttentionType
+from vllm.config import get_layers_from_vllm_config
+from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
+                                        KVCacheSpec, MLAAttentionSpec,
+                                        SlidingWindowSpec)
+
+from tpu_inference import utils
+from tpu_inference import utils as common_utils
+from tpu_inference.logger import init_logger
+from tpu_inference.runner import utils as runner_utils
+from tpu_inference.runner.input_batch import CachedRequestState, InputBatch
+from tpu_inference.runner.kv_cache import create_kv_caches
+
+if TYPE_CHECKING:
+    from vllm.v1.request import Request
+
+    from tpu_inference.runner.tpu_runner import TPUModelRunner
+
+logger = init_logger(__name__)
+
+
+class KVCacheManager:
+
+    def __init__(self, runner: "TPUModelRunner"):
+        self.runner = runner
+        # Layer pairings for cross-layer KV sharing.
+        # If an Attention layer `layer_name` is in the keys of this dict, it
+        # means this layer will perform attention using the keys and values
+        # from the KV cache of `shared_kv_cache_layers[layer_name]`.
+        self.shared_kv_cache_layers: dict[str, str] = {}
+
+    def get_kv_cache_spec(self):
+        # TODO(xiang): this hack tricks engine core to init successfully
+        block_size = self.runner.cache_config.block_size
+        use_mla = self.runner.model_config.use_mla
+        kv_cache_spec: dict[str, KVCacheSpec] = {}
+
+        # If use pure jax (MODEL_IMPL_TYPE=flax_nnx), we don't register
+        # attention into compilation config.
+        # Use FullAttentionSpec for each layer
+        # TODO(pooyam): Is it possible to merge the logic for vllm and non-vllm models?
+        if len(self.runner.vllm_config.compilation_config.
+               static_forward_context) == 0:
+            model_config = self.runner.model_config
+            parallel_config = self.runner.parallel_config
+            # Pad num_kv_heads to multiple of TP size.
+            num_kv_heads = common_utils.get_padded_num_heads(
+                model_config.get_total_num_kv_heads(),
+                self.runner.mesh.shape["model"])
+            head_size = common_utils.get_padded_head_dim(
+                model_config.get_head_size())
+            for i in range(model_config.get_num_layers(parallel_config)):
+                if use_mla:
+                    kv_cache_spec[f"layer.{i}"] = MLAAttentionSpec(
+                        block_size=block_size,
+                        num_kv_heads=num_kv_heads,
+                        head_size=head_size,
+                        dtype=self.runner.kv_cache_dtype,
+                        cache_dtype_str=self.runner.vllm_config.cache_config.
+                        cache_dtype)
+                else:
+                    kv_cache_spec[f"layer.{i}"] = FullAttentionSpec(
+                        block_size=block_size,
+                        num_kv_heads=num_kv_heads,
+                        head_size=head_size,
+                        dtype=self.runner.kv_cache_dtype)
+            if self.runner.speculative_config and self.runner.speculative_config.method == "eagle3":
+                draft_model_config = self.runner.speculative_config.draft_model_config
+                hf_config = draft_model_config.hf_config
+                num_kv_heads = common_utils.get_padded_num_heads(
+                    hf_config.num_key_value_heads,
+                    self.runner.mesh.shape["model"])
+                head_size = common_utils.get_padded_head_dim(
+                    hf_config.hidden_size // hf_config.num_attention_heads)
+
+                # Eagle3 has only 1 layer
+                for i in range(1):
+                    if use_mla:
+                        kv_cache_spec[f"layer.{i}"] = MLAAttentionSpec(
+                            block_size=block_size,
+                            num_kv_heads=num_kv_heads,
+                            head_size=head_size,
+                            dtype=self.runner.kv_cache_dtype,
+                            cache_dtype_str=self.runner.vllm_config.
+                            cache_config.cache_dtype)
+                    else:
+                        kv_cache_spec[f"draft_layer.{i}"] = FullAttentionSpec(
+                            block_size=block_size,
+                            num_kv_heads=num_kv_heads,
+                            head_size=head_size,
+                            dtype=self.runner.kv_cache_dtype)
+        else:
+            # Else propagate attention modules from compilation config.
+            layers = get_layers_from_vllm_config(self.runner.vllm_config,
+                                                 Attention)
+            for layer_name, attn_module in layers.items():
+                if (kv_tgt_layer :=
+                        attn_module.kv_sharing_target_layer_name) is not None:
+                    # The layer doesn't need its own KV cache and will use that of
+                    # the target layer. We skip creating a KVCacheSpec for it, so
+                    # that KV cache management logic will act as this layer does
+                    # not exist, and doesn't allocate KV cache for the layer. This
+                    # enables the memory saving of cross-layer kv sharing, allowing
+                    # a given amount of memory to accommodate longer context lengths
+                    # or enable more requests to be processed simultaneously.
+                    self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
+                    continue
+                if attn_module.attn_type == AttentionType.DECODER:
+                    if attn_module.sliding_window is not None:
+                        kv_cache_spec[layer_name] = SlidingWindowSpec(
+                            block_size=block_size,
+                            num_kv_heads=common_utils.get_padded_num_heads(
+                                attn_module.num_kv_heads,
+                                self.runner.mesh.shape["model"]),
+                            head_size=common_utils.get_padded_head_dim(
+                                attn_module.head_size),
+                            dtype=self.runner.kv_cache_dtype,
+                            sliding_window=attn_module.sliding_window)
+                    elif use_mla:
+                        kv_cache_spec[f"layer.{i}"] = MLAAttentionSpec(
+                            block_size=block_size,
+                            num_kv_heads=attn_module.num_kv_heads,
+                            head_size=attn_module.head_size,
+                            dtype=self.runner.kv_cache_dtype,
+                            cache_dtype_str=self.runner.vllm_config.
+                            cache_config.cache_dtype)
+                    else:
+                        kv_cache_spec[layer_name] = FullAttentionSpec(
+                            block_size=block_size,
+                            num_kv_heads=common_utils.get_padded_num_heads(
+                                attn_module.num_kv_heads,
+                                self.runner.mesh.shape["model"]),
+                            head_size=common_utils.get_padded_head_dim(
+                                attn_module.head_size),
+                            dtype=self.runner.kv_cache_dtype)
+                elif attn_module.attn_type in (AttentionType.ENCODER,
+                                               AttentionType.ENCODER_ONLY):
+                    # encoder-only attention does not need KV cache.
+                    continue
+                elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
+                    raise NotImplementedError
+                else:
+                    raise ValueError(
+                        f"Unknown attention type: {attn_module.attn_type}")
+        return kv_cache_spec
+
+    def maybe_reinitialize_input_batch(self,
+                                       kv_cache_config: KVCacheConfig) -> None:
+        block_sizes = [
+            kv_cache_group.kv_cache_spec.block_size
+            for kv_cache_group in kv_cache_config.kv_cache_groups
+        ]
+        if block_sizes != [self.runner.cache_config.block_size]:
+            assert self.runner.cache_config.cpu_offload_gb == 0, (
+                "Cannot re-initialize the input batch when CPU weight "
+                "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 "  # noqa: E501
+                "for more details.")
+            new_input_batch = InputBatch(
+                max_num_reqs=self.runner.max_num_reqs,
+                max_model_len=self.runner.max_model_len,
+                max_num_batched_tokens=self.runner.max_num_tokens,
+                pin_memory=False,
+                vocab_size=self.runner.model_config.get_vocab_size(),
+                block_sizes=block_sizes,
+            )
+            self.runner.input_batch = new_input_batch
+            self.runner.persistent_batch_manager.input_batch = new_input_batch
+
+    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
+        self.maybe_reinitialize_input_batch(kv_cache_config)
+
+        # uniform page size.
+        representative_spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec
+        page_size_bytes = representative_spec.page_size_bytes
+        self.runner.layer_name_to_kvcache_index: Dict[str, int] = {}
+        kv_caches = self.runner.kv_caches
+        num_blocks_list = []
+        for i, kv_cache_tensor in enumerate(kv_cache_config.kv_cache_tensors):
+            assert kv_cache_tensor.size % page_size_bytes == 0
+            num_blocks = kv_cache_tensor.size // page_size_bytes
+            dp_size = self.runner.vllm_config.sharding_config.total_dp_size
+            # num_blocks must be a multiple of dp_size
+            num_blocks = math.ceil(num_blocks / dp_size) * dp_size
+            # NOTE: we'll multiply the num_kv_heads by 2 in the function
+            kv_cache = create_kv_caches(
+                num_blocks=num_blocks,
+                block_size=representative_spec.block_size,
+                num_kv_heads=representative_spec.num_kv_heads,
+                head_size=representative_spec.head_size,
+                mesh=self.runner.mesh,
+                layer_names=[f'kv_cache_tensor.{i}'],
+                cache_dtype=t2j_dtype(representative_spec.dtype),
+            )[0]
+            kv_caches.append(kv_cache)
+            num_blocks_list.append(num_blocks)
+            for layer_name in kv_cache_tensor.shared_by:
+                self.runner.layer_name_to_kvcache_index[layer_name] = i
+
+        if self.shared_kv_cache_layers:
+            for layer_name, target_layer_name in self.shared_kv_cache_layers.items(
+            ):
+                self.runner.layer_name_to_kvcache_index[
+                    layer_name] = self.runner.layer_name_to_kvcache_index[
+                        target_layer_name]
+
+        logger.info(
+            f"Init kv-cache | "
+            f"num_layers={len(kv_caches)} | "
+            f"shape=(num_blocks, {kv_caches[0].shape[1:]}) | "
+            f"num_blocks={num_blocks_list} | "
+            f"sharding={kv_caches[0].sharding} | "
+            f"dtype={kv_caches[0].dtype} | "
+            f"hbm={utils.hbm_usage_gb(self.runner.mesh.devices.flatten())}Gb")
+
+    @staticmethod
+    @functools.partial(jax.jit)
+    def _jitted_gather_kv_cache(kv_caches: List[jax.Array],
+                                block_ids: jax.Array) -> List[jax.Array]:
+        """
+        JIT-compiled function to gather KV cache slices for all layers at once.
+        This uses jax.tree.map to apply the operation across all layers.
+        """
+
+        def gather_and_reshape(layer_kv_cache):
+            return layer_kv_cache.at[block_ids].get().reshape(
+                -1, *layer_kv_cache.shape[2:])
+
+        return jax.tree.map(gather_and_reshape, kv_caches)
+
+    @staticmethod
+    @functools.partial(
+        jax.jit,
+        static_argnames=("len_block"),
+    )
+    def _jitted_gather_continuous_kv_cache(kv_caches: List[jax.Array],
+                                           start_block,
+                                           len_block) -> List[jax.Array]:
+        """
+        JIT-compiled function to gather KV cache slices for all layers at once.
+        This uses jax.tree.map to apply the operation across all layers.
+        """
+
+        def gather_and_reshape(layer_kv_cache):
+            shape = layer_kv_cache.shape
+            return jax.lax.dynamic_slice_in_dim(layer_kv_cache,
+                                                start_block,
+                                                len_block,
+                                                axis=0).reshape(
+                                                    -1, *shape[2:])
+
+        return jax.tree.map(gather_and_reshape, kv_caches)
+
+    @staticmethod
+    @functools.partial(
+        jax.jit,
+        static_argnames=("block_size"),
+        donate_argnames=(
+            "kv_caches",
+            "kv_cache_slices",
+        ),
+    )
+    def _jitted_insert_kv_cache(
+        block_size,
+        kv_caches: List[jax.Array],
+        kv_cache_slices: List[jax.Array],
+        block_numbers: jax.Array,
+    ) -> List[jax.Array]:
+        """
+        JIT-compiled function to insert KV cache slices into the physical
+        cache for all layers at once. This fuses the pad, reshape, and scatter
+        operations into a single efficient kernel.
+        """
+
+        def _update_layer(cache, slices):
+            """The function to apply to each layer's cache and slices."""
+            reshaped_slices = slices.reshape(-1, 1, block_size,
+                                             *slices.shape[1:])
+            for (i, block_idx) in enumerate(block_numbers):
+                cache = jax.lax.dynamic_update_slice_in_dim(cache,
+                                                            reshaped_slices[i],
+                                                            block_idx,
+                                                            axis=0)
+            return cache
+
+        return jax.tree.map(_update_layer, kv_caches, kv_cache_slices)
+
+    @staticmethod
+    @functools.partial(
+        jax.jit,
+        static_argnames=("block_size"),
+        donate_argnames=(
+            "kv_caches",
+            "kv_cache_slices",
+        ),
+    )
+    def _jitted_insert_continuous_kv_cache(
+        block_size,
+        kv_caches: List[jax.Array],
+        kv_cache_slices: List[jax.Array],
+        start_block,
+    ) -> List[jax.Array]:
+        """
+        JIT-compiled function to insert KV cache slices into continuous blocks.
+        Makes use of dynamic_update_slice_in_dim.
+        """
+
+        def _update_layer(cache, slices):
+            """The function to apply to each layer's cache and slices."""
+            reshaped_slices = slices.reshape(-1, block_size, *slices.shape[1:])
+
+            return jax.lax.dynamic_update_slice_in_dim(cache,
+                                                       reshaped_slices,
+                                                       start_block,
+                                                       axis=0)
+
+        return jax.tree.map(_update_layer, kv_caches, kv_cache_slices)
+
+    def get_kv_cache_for_block_ids(
+        self,
+        block_ids: List[int],
+    ) -> List[jax.Array]:
+        """
+        Extracts the KV cache slices for a given list of block IDs.
+        This assumes all provided blocks are full.
+
+        Args:
+            block_ids: A list of block IDs to extract KV cache for.
+
+        Returns:
+            A list of JAX arrays, with each array representing the KV cache
+            slices for a layer, concatenated for all blocks.
+        """
+        if block_ids == list(range(block_ids[0],
+                                   block_ids[0] + len(block_ids))):
+            with runner_utils.LatencyTracker(
+                    "BatchedGatherKVSlices-for-blocks"):
+                batched_kv_cache_per_layer = self._jitted_gather_continuous_kv_cache(
+                    self.runner.kv_caches, block_ids[0], len(block_ids))
+
+        else:
+            with runner_utils.LatencyTracker(
+                    "BatchedGatherKVSlices-for-blocks"):
+                batched_kv_cache_per_layer = self._jitted_gather_kv_cache(
+                    self.runner.kv_caches, jnp.array(block_ids))
+        return batched_kv_cache_per_layer
+
+    def transfer_kv_cache(self,
+                          kv_cache_slices: List[jax.Array]) -> List[jax.Array]:
+        """
+        Transfers KV cache slices to the runner's mesh.
+
+        This is used when a KV cache generated on one runner (e.g., a prefill
+        runner) needs to be used on another runner (e.g., a decode runner)
+        with a different device mesh. The transfer is asynchronous.
+
+        Args:
+            kv_cache_slices: A list of JAX arrays, where each array contains
+                the KV cache slices for a specific layer. The shape of each
+                slice is expected to be (num_tokens, num_kv_heads * 2, head_size).
+
+        Returns:
+            A new list of JAX arrays representing the KV cache slices, sharded
+            across the runner's device mesh.
+        """
+        # The KV cache slices have a shape of (num_tokens, num_kv_heads * 2, head_size).
+        # We shard along the num_kv_heads dimension (axis=1), which corresponds
+        # to the "model" axis of the mesh for tensor parallelism.
+        logger.debug(
+            f"Transferring kv cache shape {len(kv_cache_slices)} * {kv_cache_slices[0].shape} sharding {kv_cache_slices[0].sharding} size {kv_cache_slices[0].nbytes * len(kv_cache_slices)/1024/1024} Mbytes"
+        )
+        sharding = NamedSharding(self.runner.mesh,
+                                 PartitionSpec(None, "model"))
+        if envs.VLLM_TPU_USING_PATHWAYS:
+            from pathwaysutils.experimental import \
+                reshard as experimental_reshard
+
+            def get_sharding(x):
+                return sharding
+
+            sharding_spec_pytree = jax.tree.map(get_sharding, kv_cache_slices)
+            transferred_kv_cache = experimental_reshard.reshard(
+                kv_cache_slices,
+                sharding_spec_pytree,
+                donate=False,
+            )
+        else:
+            transferred_kv_cache = jax.device_put(kv_cache_slices, sharding)
+
+        jax.block_until_ready(transferred_kv_cache)
+        return transferred_kv_cache
+
+    def insert_request_with_kv_cache(
+        self,
+        request: "Request",
+        kv_cache_slices: List[jax.Array],
+        block_ids: List[List[int]],
+    ):
+        """
+        Inserts a request and its KV cache into the runner. This is used to
+        transfer a request from a prefill runner to a decode runner.
+
+        The provided KV cache slices are copied into the physical blocks
+        allocated for the request. The runner's internal state is then updated
+        to include the request.
+
+        Args:
+            request: The vLLM request object, containing the state after prefill.
+            kv_cache_slices: The KV cache for the request, already transferred
+                to this runner's mesh. This is a list of JAX arrays, one per layer.
+            block_ids: The physical block numbers allocated for this request on
+                this runner. This is a list of lists, for each KV cache group.
+        """
+        # Assume one KV cache group for now, which is consistent with current setup.
+        if len(block_ids) > 1:
+            raise NotImplementedError(
+                "Inserting KV cache for models with multiple KV cache groups "
+                "is not supported yet.")
+        block_numbers = block_ids[0]
+        if block_numbers == list(
+                range(block_numbers[0],
+                      block_numbers[0] + len(block_numbers))):
+            # For continuous blocks we use slice instead of scatter.
+            start_block = block_numbers[0]
+            with runner_utils.LatencyTracker(
+                    f"JittedInsertContinuousKVCache-b{len(block_numbers)}"):
+                logger.debug(f"inserting to continuous blocks {block_numbers}")
+                self.runner.kv_caches = KVCacheManager._jitted_insert_continuous_kv_cache(
+                    self.runner.block_size,
+                    self.runner.kv_caches,
+                    kv_cache_slices,
+                    start_block,
+                )
+        else:
+            with runner_utils.LatencyTracker(
+                    f"JittedInsertKVCache-b{len(block_numbers)}"):
+                logger.debug(
+                    f"inserting to non continuous blocks {block_numbers}")
+                self.runner.kv_caches = KVCacheManager._jitted_insert_kv_cache(
+                    self.runner.block_size,
+                    self.runner.kv_caches,
+                    kv_cache_slices,
+                    jnp.array(block_numbers),
+                )
+
+        logger.debug(
+            f"Updated kv cache entries cnt={len(self.runner.kv_caches)}")
+
+        # Update runner's internal state to track the new request.
+        req_id = request.request_id
+        if req_id in self.runner.requests:
+            logger.warning(
+                f"Request {req_id} already exists in the runner. Overwriting.")
+
+        # Create a CachedRequestState object to add to the input batch.
+        req_state = CachedRequestState(
+            req_id=request.request_id,
+            prompt_token_ids=request.prompt_token_ids,
+            output_token_ids=[request.all_token_ids[-1]],
+            sampling_params=request.sampling_params,
+            block_ids=tuple(block_ids),
+            num_computed_tokens=request.num_computed_tokens,
+            lora_request=request.lora_request,
+            mm_features=getattr(request, "mm_features", []),
+            pooling_params=getattr(request, "pooling_params", None),
+            generator=None,
+        )
+
+        self.runner.requests[req_id] = req_state
+        self.runner.input_batch.add_request(req_state)
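
Note on the gather paths in get_kv_cache_for_block_ids above: when the block IDs form a contiguous run, the manager uses a dynamic slice; otherwise it falls back to an indexed gather over the block axis. The snippet below is a minimal, self-contained sketch of that pattern on toy arrays; all shapes, block counts, and function names here are illustrative and are not taken from the package.

import jax
import jax.numpy as jnp

# Toy per-layer caches: 2 layers, 8 blocks, block_size=4, K/V fused into 2 heads, head_size=16.
kv_caches = [
    jnp.arange(8 * 4 * 2 * 16, dtype=jnp.float32).reshape(8, 4, 2, 16)
    for _ in range(2)
]


def gather_scattered(kv_caches, block_ids):
    # Arbitrary block IDs: index the block axis, then flatten blocks into a token axis.
    def gather_and_reshape(layer_cache):
        return layer_cache.at[block_ids].get().reshape(-1, *layer_cache.shape[2:])

    return jax.tree.map(gather_and_reshape, kv_caches)


def gather_contiguous(kv_caches, start_block, num_blocks):
    # A contiguous run of blocks: a dynamic slice avoids the gather entirely.
    def slice_and_reshape(layer_cache):
        return jax.lax.dynamic_slice_in_dim(
            layer_cache, start_block, num_blocks, axis=0
        ).reshape(-1, *layer_cache.shape[2:])

    return jax.tree.map(slice_and_reshape, kv_caches)


print(gather_scattered(kv_caches, jnp.array([1, 5, 2]))[0].shape)  # (12, 2, 16): 3 blocks * 4 tokens
print(gather_contiguous(kv_caches, 3, 2)[0].shape)                 # (8, 2, 16): 2 blocks * 4 tokens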
tpu_inference/runner/lora_utils.py
@@ -0,0 +1,92 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import numpy as np
+from torchax.interop import jax_view
+from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA
+from vllm.lora.request import LoRARequest
+
+from tpu_inference.layers.vllm.sharding import update_lora
+
+if TYPE_CHECKING:
+    from tpu_inference.runner.tpu_runner import TPUModelRunner
+
+
+class LoraUtils:
+
+    def __init__(self, runner: "TPUModelRunner"):
+        self.runner = runner
+
+    def set_active_loras(self, num_scheduled_tokens_per_req,
+                         total_num_scheduled_tokens,
+                         padded_total_num_scheduled_tokens):
+        # We need to respect padding when activating LoRA adapters
+        padded_num_scheduled_tokens_per_req = np.copy(
+            num_scheduled_tokens_per_req
+        )  # Copying to avoid accidental state corruption bugs
+        padded_num_scheduled_tokens_per_req[-1] += \
+            padded_total_num_scheduled_tokens - total_num_scheduled_tokens
+
+        prompt_lora_mapping: tuple[int, ...]  # of size input_batch.num_reqs
+        token_lora_mapping: tuple[int,
+                                  ...]  # of size np.sum(num_scheduled_tokens)
+        lora_requests: set[LoRARequest]
+        prompt_lora_mapping, token_lora_mapping, lora_requests = \
+            self.runner.input_batch.make_lora_inputs(padded_num_scheduled_tokens_per_req)
+        # One should not put lora_manager.set_active_loras under
+        # torchax.default_env() because set_active_loras also load lora from
+        # disk and torchax currently does not support that. Here we load the
+        # lora and set the lora weight to the linear layers.
+        self.runner._set_active_loras(prompt_lora_mapping, token_lora_mapping,
+                                      lora_requests)
+
+        params_and_buffers = update_lora(
+            self.runner.model.model, initial_params_buffers=self.runner.state)
+        self.runner.state = jax_view(params_and_buffers)
+
+    def extract_lora_metadata(self):
+        if self.runner.lora_config is None:
+            return None
+
+        metadata = {}
+        punica_wrapper = None
+        for _, m in self.runner.model.model.named_modules():
+            if isinstance(m, BaseLinearLayerWithLoRA):
+                assert getattr(
+                    m, 'punica_wrapper', None
+                ) is not None, 'A lora wrapper should have contained a punica_wrapper'
+                punica_wrapper = m.punica_wrapper
+                break
+        assert punica_wrapper is not None, 'Should have been able to find a punica wrapper from the Lora wrapper.'
+
+        # vars does not show inherited methods or class attributes but this is
+        # fine because we only care about instance attributes.
+        for k in vars(punica_wrapper):
+            v = getattr(punica_wrapper, k, None)
+            if k == 'device':  # Exclude string as it can't be traced by jax.jit
+                continue
+            metadata[k] = v
+        return jax_view(metadata)


+def replace_lora_metadata(model, metadata: dict, lora_config) -> dict:
+    if lora_config is None or not metadata:
+        return {}
+
+    original_metadata = {}
+    punica_wrapper = None
+    for _, m in model.named_modules():
+        if isinstance(m, BaseLinearLayerWithLoRA):
+            assert getattr(
+                m, 'punica_wrapper', None
+            ) is not None, 'A lora wrapper should have contained a punica_wrapper'
+            punica_wrapper = m.punica_wrapper
+            break
+    assert punica_wrapper is not None, 'Should have been able to find a punica wrapper from the Lora wrapper.'
+
+    for k in vars(punica_wrapper):
+        if k in metadata:
+            original_metadata[k] = getattr(punica_wrapper, k)
+            setattr(punica_wrapper, k, metadata[k])
+    return original_metadata
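
The replace_lora_metadata helper above follows a swap-and-restore pattern: it overwrites the punica wrapper's per-batch attributes with the supplied values and returns the originals so the caller can put them back afterwards. Below is a minimal sketch of that pattern using a made-up stand-in object; the class and attribute names are hypothetical and not part of the package.

class FakePunicaWrapper:
    # Hypothetical stand-in holding per-batch LoRA metadata.
    def __init__(self):
        self.token_lora_indices = [0, 0, 1]
        self.lora_indices_per_batch = [0, 1]


def swap_metadata(wrapper, metadata):
    # Set the new values and return the originals so they can be restored later.
    original = {}
    for key, value in metadata.items():
        if hasattr(wrapper, key):
            original[key] = getattr(wrapper, key)
            setattr(wrapper, key, value)
    return original


wrapper = FakePunicaWrapper()
saved = swap_metadata(wrapper, {"token_lora_indices": [2, 2, 2]})
print(wrapper.token_lora_indices)  # [2, 2, 2] while the swapped-in metadata is active
swap_metadata(wrapper, saved)      # restore the original values
print(wrapper.token_lora_indices)  # [0, 0, 1]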