tpu-inference 0.0.1rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (174)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +374 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +648 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +88 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +203 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +235 -0
  27. tpu_inference/__init__.py +53 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +49 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +727 -0
  37. tpu_inference/distributed/utils.py +60 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +160 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +382 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1566 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1501 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1603 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +396 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +469 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +110 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +331 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +368 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +310 -0
  120. tpu_inference/models/__init__.py +0 -0
  121. tpu_inference/models/common/__init__.py +0 -0
  122. tpu_inference/models/common/model_loader.py +478 -0
  123. tpu_inference/models/jax/__init__.py +0 -0
  124. tpu_inference/models/jax/deepseek_v3.py +868 -0
  125. tpu_inference/models/jax/gpt_oss.py +492 -0
  126. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  127. tpu_inference/models/jax/llama3.py +376 -0
  128. tpu_inference/models/jax/llama4.py +629 -0
  129. tpu_inference/models/jax/llama_eagle3.py +336 -0
  130. tpu_inference/models/jax/llama_guard_4.py +361 -0
  131. tpu_inference/models/jax/qwen2.py +376 -0
  132. tpu_inference/models/jax/qwen2_5_vl.py +1218 -0
  133. tpu_inference/models/jax/qwen3.py +303 -0
  134. tpu_inference/models/jax/utils/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/file_utils.py +96 -0
  136. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  137. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  138. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  139. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  140. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  141. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  142. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  143. tpu_inference/models/jax/utils/quantization/quantization_utils.py +650 -0
  144. tpu_inference/models/jax/utils/weight_utils.py +584 -0
  145. tpu_inference/models/vllm/__init__.py +0 -0
  146. tpu_inference/models/vllm/vllm_model_wrapper.py +293 -0
  147. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  148. tpu_inference/platforms/__init__.py +2 -0
  149. tpu_inference/platforms/tpu_platform.py +275 -0
  150. tpu_inference/runner/__init__.py +0 -0
  151. tpu_inference/runner/block_table.py +122 -0
  152. tpu_inference/runner/compilation_manager.py +865 -0
  153. tpu_inference/runner/input_batch.py +435 -0
  154. tpu_inference/runner/kv_cache.py +132 -0
  155. tpu_inference/runner/kv_cache_manager.py +478 -0
  156. tpu_inference/runner/lora_utils.py +92 -0
  157. tpu_inference/runner/multimodal_manager.py +217 -0
  158. tpu_inference/runner/persistent_batch_manager.py +282 -0
  159. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  160. tpu_inference/runner/structured_decoding_manager.py +87 -0
  161. tpu_inference/runner/tpu_runner.py +1744 -0
  162. tpu_inference/runner/utils.py +426 -0
  163. tpu_inference/spec_decode/__init__.py +0 -0
  164. tpu_inference/spec_decode/jax/__init__.py +0 -0
  165. tpu_inference/spec_decode/jax/eagle3.py +417 -0
  166. tpu_inference/tpu_info.py +78 -0
  167. tpu_inference/utils.py +340 -0
  168. tpu_inference/worker/__init__.py +0 -0
  169. tpu_inference/worker/tpu_worker.py +458 -0
  170. tpu_inference-0.0.1rc1.dist-info/METADATA +108 -0
  171. tpu_inference-0.0.1rc1.dist-info/RECORD +174 -0
  172. tpu_inference-0.0.1rc1.dist-info/WHEEL +5 -0
  173. tpu_inference-0.0.1rc1.dist-info/licenses/LICENSE +201 -0
  174. tpu_inference-0.0.1rc1.dist-info/top_level.txt +2 -0
tpu_inference/models/vllm/vllm_model_wrapper.py
@@ -0,0 +1,293 @@
+ import copy
+ import functools
+ from collections.abc import Sequence
+ from contextlib import nullcontext
+ from typing import Any, List, Optional, Tuple
+ from unittest.mock import patch
+
+ import jax
+ import torch
+ import torch.nn
+ import torchax
+ import vllm.envs as vllm_envs
+ from flax.typing import PRNGKey
+ from jax.sharding import Mesh, NamedSharding, PartitionSpec
+ from torchax.interop import jax_view, torch_view
+ from torchax.ops.mappings import TORCH_DTYPE_TO_JAX
+ from vllm.config import VllmConfig
+ from vllm.forward_context import set_forward_context
+ from vllm.lora.layers import BaseLayerWithLoRA
+ from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
+ from vllm.model_executor.model_loader import get_model as vllm_get_model
+ from vllm.model_executor.models import supports_lora, supports_multimodal
+ from vllm.sequence import IntermediateTensors
+
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
+ from tpu_inference.layers.vllm.sharding import shard_model_to_tpu
+ from tpu_inference.logger import init_logger
+ from tpu_inference.models.jax.jax_intermediate_tensor import \
+     JaxIntermediateTensors
+ from tpu_inference.models.vllm.vllm_model_wrapper_context import (
+     get_vllm_model_wrapper_context, set_vllm_model_wrapper_context)
+ from tpu_inference.runner.lora_utils import replace_lora_metadata
+
+ logger = init_logger(__name__)
+
+
+ class _VllmRunner(torch.nn.Module):
+
+     def __init__(self, vllm_model: torch.nn.Module):
+         super().__init__()
+         self.vllm_model = vllm_model
+
+     def forward(self, **kwargs) -> torch.Tensor:
+         if "hidden_state" in kwargs:
+             return self.compute_logits(kwargs["hidden_state"])
+         else:
+             return self.compute_hidden_state(
+                 kwargs["input_ids"],
+                 kwargs["positions"],
+                 kwargs["intermediate_tensors"],
+                 kwargs["inputs_embeds"],
+             )
+
+     def compute_hidden_state(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         intermediate_tensors: Optional[IntermediateTensors],
+         inputs_embeds: Optional[torch.Tensor],
+     ) -> torch.Tensor:
+         hidden_state = self.vllm_model(input_ids, positions,
+                                        intermediate_tensors, inputs_embeds)
+         return hidden_state
+
+     def compute_logits(self, hidden_state: torch.Tensor) -> torch.Tensor:
+         return self.vllm_model.compute_logits(hidden_state)
+
+
+ class VllmModelWrapper:
+     """Wraps a vLLM PyTorch model and lets it run on the JAX engine."""
+
+     rng: PRNGKey
+     mesh: Mesh
+     model: _VllmRunner
+
+     def __init__(self, vllm_config: VllmConfig, rng: PRNGKey, mesh: Mesh):
+         self.vllm_config = vllm_config
+         self.rng = rng
+         self.mesh = mesh
+
+         self.vllm_config.quant_config = get_tpu_quantization_config(
+             self.vllm_config, self.mesh)
+
+     def load_weights(self):
+         # Set up to load the model onto CPU first.
+         # Cache the device slice config, since the device config cannot be deepcopied.
+         modified_slice_config = False
+         if hasattr(
+                 self.vllm_config.device_config,
+                 'slice') and self.vllm_config.device_config.slice is not None:
+             slice_config = self.vllm_config.device_config.slice
+             modified_slice_config = True
+             self.vllm_config.device_config.slice = None
+         self.vllm_config.compilation_config.static_forward_context.clear()
+
+         vllm_config_for_load = copy.deepcopy(self.vllm_config)
+         if modified_slice_config:
+             self.vllm_config.device_config.slice = slice_config
+         assert self.vllm_config.model_config.dtype in TORCH_DTYPE_TO_JAX, "The model_config.dtype must be a PyTorch dtype."
+         vllm_config_for_load.device_config.device = "cpu"
+         # The cached compilation config was cleared above; otherwise vLLM model init would fail.
+
+         # When expert parallelism is enabled, vLLM loads weights in a
+         # sharding-aware manner. Since tpu-inference has its own sharding
+         # logic, this may cause errors, so we disable it during weight loading.
+         vllm_config_for_load.parallel_config.enable_expert_parallel = False
+
+         use_random_weights = (
+             vllm_config_for_load.load_config.load_format == "dummy")
+         if use_random_weights:
+             logger.info(
+                 "Initializing vLLM model with random weights, weight loading skipped."
+             )
+         # The DummyModelLoader in vLLM calls torch._sync on the torch_xla path when
+         # it detects the TPU platform, but we don't need it and it causes a crash
+         # without proper setup.
+         load_context = patch(
+             "torch._sync",
+             return_value=None) if use_random_weights else nullcontext()
+
+         # By default, load weights to the CPU device first. If we are running
+         # under Pathways, this would cause weights to be loaded on a CPU-only
+         # node, so we need to remove this context.
+         jax_context = jax.default_device(
+             jax.devices("cpu")
+             [0]) if not vllm_envs.VLLM_TPU_USING_PATHWAYS else nullcontext()
+
+         # Load the vLLM model and wrap it into a new model whose forward
+         # function can calculate the hidden_state and logits.
+         with load_context, jax_context:
+             vllm_model = vllm_get_model(vllm_config=vllm_config_for_load)
+             lora_manager = None
+             if vllm_config_for_load.lora_config is not None:
+                 # Replace layers in the model with LoRA layers.
+                 with torchax.default_env():
+                     # The "device" argument of load_lora_model sets the device
+                     # used by the punica wrapper.
+                     lora_manager, vllm_model = load_lora_model(
+                         vllm_model, vllm_config_for_load, device="jax")
+                 replace_set_lora(vllm_model)
+
+         static_forward_context = vllm_config_for_load.compilation_config.static_forward_context
+         self.vllm_config.compilation_config.static_forward_context = static_forward_context
+
+         self.model = _VllmRunner(vllm_model)
+         params_and_buffers = shard_model_to_tpu(self.model, self.mesh)
+
+         # Returning to JAX land, so we need to wrap the result into a JaxValue.
+         return jax_view(params_and_buffers), lora_manager
+
+     def jit_step_func(self):
+
+         @functools.partial(
+             jax.jit,
+             donate_argnames=("kv_caches", ),
+             compiler_options={
+                 "xla_tpu_all_gather_collective_matmul_mode":
+                 "post_spmd_conservative",
+                 "xla_tpu_reduce_scatter_collective_matmul_mode":
+                 "post_spmd_conservative"
+             },
+             static_argnames=("layer_name_to_kvcache_index", "is_first_rank",
+                              "is_last_rank"),
+         )
+         def step_fun(
+             params_and_buffers,  # This has been wrapped into a torchax TorchValue.
+             kv_caches: List[jax.Array],
+             input_ids: jax.Array,
+             attn_metadata: AttentionMetadata,
+             input_embeds: jax.Array,
+             input_positions: jax.Array,
+             layer_name_to_kvcache_index: Sequence[Tuple[str, int]],
+             lora_metadata,
+             intermediate_tensors: JaxIntermediateTensors = None,
+             is_first_rank: bool = True,
+             is_last_rank: bool = True,
+             *args,
+         ) -> Tuple[List[jax.Array], jax.Array]:
+             layer_name_to_kvcache_index = dict(layer_name_to_kvcache_index)
+             lora_metadata = torch_view(lora_metadata)
+             with torchax.default_env(), set_vllm_model_wrapper_context(
+                     kv_caches=kv_caches,
+                     mesh=self.mesh,
+                     layer_name_to_kvcache_index=layer_name_to_kvcache_index
+             ), set_forward_context(attn_metadata=attn_metadata,
+                                    vllm_config=self.vllm_config):
+                 # We need to wrap args from JAX land into TorchValue with
+                 # torch_view in order to call the Torch function.
+                 original_lora_metadata = replace_lora_metadata(
+                     self.model, lora_metadata, self.vllm_config.lora_config)
+                 if not is_first_rank:
+                     intermediate_tensors = intermediate_tensors.to_torch()
+                 output_from_torch = torch.func.functional_call(
+                     self.model,
+                     torch_view(params_and_buffers),
+                     kwargs={
+                         "input_ids": torch_view(input_ids),
+                         "positions": torch_view(input_positions),
+                         "intermediate_tensors": None,
+                         "inputs_embeds": None,
+                     },
+                     tie_weights=False,
+                 )
+                 replace_lora_metadata(self.model, original_lora_metadata,
+                                       self.vllm_config.lora_config)
+                 vllm_model_wrapper_context = get_vllm_model_wrapper_context()
+                 new_kv_caches = vllm_model_wrapper_context.kv_caches
+                 # Wrap the output (hidden states or intermediate tensors)
+                 # from Torch land into a JaxValue for the JAX code to consume.
+                 if not is_last_rank:
+                     output = JaxIntermediateTensors.from_torch(output_from_torch)
+                 else:
+                     output = jax_view(output_from_torch)
+                 return new_kv_caches, output, []
+
+         return step_fun
+
+     def jit_compute_logits_func(self):
+
+         @functools.partial(
+             jax.jit,
+             out_shardings=(NamedSharding(self.mesh,
+                                          PartitionSpec("data", "model"))),
+         )
+         def compute_logits_func(
+             params_and_buffers: Any,
+             hidden_states: jax.Array,
+             lora_metadata,
+         ) -> jax.Array:
+             lora_metadata = torch_view(lora_metadata)
+             with torchax.default_env(), set_vllm_model_wrapper_context(
+                     kv_caches=None, mesh=self.mesh):
+                 original_lora_metadata = replace_lora_metadata(
+                     self.model, lora_metadata, self.vllm_config.lora_config)
+                 logits = torch.func.functional_call(
+                     self.model,
+                     torch_view(params_and_buffers),
+                     kwargs={
+                         "hidden_state": torch_view(hidden_states),
+                     },
+                     tie_weights=False,
+                 )
+                 replace_lora_metadata(self.model, original_lora_metadata,
+                                       self.vllm_config.lora_config)
+                 return jax_view(logits)
+
+         return compute_logits_func
+
+
+ def load_lora_model(model: torch.nn.Module, vllm_config: VllmConfig,
+                     device: str) -> Tuple[LRUCacheWorkerLoRAManager, torch.nn.Module]:
+     if not supports_lora(model):
+         raise ValueError(
+             f"{model.__class__.__name__} does not support LoRA yet.")
+
+     if supports_multimodal(model):
+         logger.warning("For multimodal models, vLLM currently only "
+                        "supports adding LoRA to the language model.")
+
+     # Add a LoRA manager to the model runner.
+     lora_manager = LRUCacheWorkerLoRAManager(
+         vllm_config,
+         device,
+         model.embedding_modules,
+     )
+     return lora_manager, lora_manager.create_lora_manager(model)
+
+
+ # We replace these methods because set_lora and reset_lora need to run under
+ # the torchax env.
+ def replace_set_lora(model):
+
+     def _tpu_set_lora(
+         self,
+         index: int,
+         lora_a: torch.Tensor,
+         lora_b: torch.Tensor,
+     ):
+         with torchax.default_env():
+             self._original_set_lora(index, lora_a, lora_b)
+
+     def _tpu_reset_lora(self, index: int):
+         with torchax.default_env():
+             self._original_reset_lora(index)
+
+     for _, module in model.named_modules():
+         if isinstance(module, BaseLayerWithLoRA):
+             module._original_set_lora = module.set_lora
+             module._original_reset_lora = module.reset_lora
+             module.set_lora = _tpu_set_lora.__get__(module, module.__class__)
+             module.reset_lora = _tpu_reset_lora.__get__(
+                 module, module.__class__)
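
For orientation, the wrapper above is driven in two phases: load_weights() returns a jax_view of the sharded params and buffers plus an optional LoRA manager, and jit_step_func() returns a jitted step function whose donated kv_caches come back updated on every call. The sketch below is illustrative only; it assumes a fully constructed VllmConfig and runner-provided inputs (input IDs, attention metadata, LoRA metadata), and the helper names build_wrapper and run_one_step are hypothetical, not part of this package.

import jax
import numpy as np
from jax.sharding import Mesh

from tpu_inference.models.vllm.vllm_model_wrapper import VllmModelWrapper


def build_wrapper(vllm_config):
    # Hypothetical 1 x N single-host mesh using the "data" and "model" axes
    # referenced by jit_compute_logits_func's out_shardings above.
    devices = np.array(jax.devices()).reshape(1, -1)
    mesh = Mesh(devices, ("data", "model"))
    wrapper = VllmModelWrapper(vllm_config=vllm_config,
                               rng=jax.random.PRNGKey(0),
                               mesh=mesh)
    params, lora_manager = wrapper.load_weights()
    return wrapper, params, lora_manager


def run_one_step(wrapper, params, kv_caches, input_ids, attn_metadata,
                 input_embeds, input_positions, layer_name_to_kvcache_index,
                 lora_metadata):
    # kv_caches are donated to the jitted step and returned updated, together
    # with the hidden states (or intermediate tensors on non-last pipeline
    # ranks) and an empty aux list. The static layer-name mapping is passed as
    # a hashable tuple of (layer_name, index) pairs.
    step = wrapper.jit_step_func()
    new_kv_caches, output, _ = step(
        params, kv_caches, input_ids, attn_metadata, input_embeds,
        input_positions,
        layer_name_to_kvcache_index=tuple(layer_name_to_kvcache_index),
        lora_metadata=lora_metadata)
    return new_kv_caches, output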
tpu_inference/models/vllm/vllm_model_wrapper_context.py
@@ -0,0 +1,45 @@
+ from contextlib import contextmanager
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional
+
+ import jax
+ from jax.sharding import Mesh
+
+
+ @dataclass
+ class VllmModelWrapperContext:
+     kv_caches: List[jax.Array]
+     mesh: Mesh
+     layer_name_to_kvcache_index: Dict[str, int]
+
+
+ _vllm_model_wrapper_context: Optional[VllmModelWrapperContext] = None
+
+
+ def get_vllm_model_wrapper_context() -> VllmModelWrapperContext:
+     assert _vllm_model_wrapper_context is not None, (
+         "VllmModelWrapperContext is not set. "
+         "Please use `set_vllm_model_wrapper_context` to set the VllmModelWrapperContext."
+     )
+     return _vllm_model_wrapper_context
+
+
+ @contextmanager
+ def set_vllm_model_wrapper_context(
+     *,
+     kv_caches: List[jax.Array],
+     mesh: Mesh,
+     layer_name_to_kvcache_index: Dict[str, int] = None,
+ ):
+     global _vllm_model_wrapper_context
+     prev_context = _vllm_model_wrapper_context
+     _vllm_model_wrapper_context = VllmModelWrapperContext(
+         kv_caches=kv_caches,
+         mesh=mesh,
+         layer_name_to_kvcache_index=layer_name_to_kvcache_index,
+     )
+
+     try:
+         yield
+     finally:
+         _vllm_model_wrapper_context = prev_context
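
This module is the bridge that lets the Torch-side layers, invoked through torch.func.functional_call in step_fun above, reach the JAX KV caches: the wrapper installs the context, layer code reads it back with get_vllm_model_wrapper_context(), and step_fun later picks up the updated caches from the same context. A minimal, self-contained sketch of that contract; the mesh layout, cache shape, and layer name below are made up for illustration.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh

from tpu_inference.models.vllm.vllm_model_wrapper_context import (
    get_vllm_model_wrapper_context, set_vllm_model_wrapper_context)

mesh = Mesh(np.array(jax.devices()).reshape(1, -1), ("data", "model"))
kv_caches = [jnp.zeros((2, 16, 8, 128), dtype=jnp.bfloat16)]  # toy shape

with set_vllm_model_wrapper_context(
        kv_caches=kv_caches,
        mesh=mesh,
        layer_name_to_kvcache_index={"model.layers.0.self_attn.attn": 0}):
    ctx = get_vllm_model_wrapper_context()
    # An attention layer would look up its cache by index here and may write
    # an updated cache back into ctx.kv_caches.
    idx = ctx.layer_name_to_kvcache_index["model.layers.0.self_attn.attn"]
    assert ctx.kv_caches[idx] is kv_caches[0]

# The previous (None) context is restored on exit, so calling
# get_vllm_model_wrapper_context() here would raise an AssertionError.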
tpu_inference/platforms/__init__.py
@@ -0,0 +1,2 @@
+ # ruff: noqa
+ from tpu_inference.platforms.tpu_platform import TpuPlatform
tpu_inference/platforms/tpu_platform.py
@@ -0,0 +1,275 @@
+ # SPDX-License-Identifier: Apache-2.0
+
+ from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast
+
+ import jax.numpy as jnp
+ import torch
+ import vllm.envs as vllm_envs
+ from tpu_info import device
+ from vllm.inputs import ProcessorInputs, PromptType
+ from vllm.platforms.interface import Platform, PlatformEnum
+ from vllm.sampling_params import SamplingParams, SamplingType
+
+ from tpu_inference import envs
+ from tpu_inference.layers.common.sharding import ShardingConfigManager
+ from tpu_inference.logger import init_logger
+ from tpu_inference.utils import to_jax_dtype, to_torch_dtype
+
+ if TYPE_CHECKING:
+     from vllm.attention.backends.registry import AttentionBackendEnum
+     from vllm.config import BlockSize, ModelConfig, VllmConfig
+     from vllm.pooling_params import PoolingParams
+ else:
+     BlockSize = None
+     ModelConfig = None
+     VllmConfig = None
+     PoolingParams = None
+     AttentionBackendEnum = None
+
+ logger = init_logger(__name__)
+
+
+ class TpuPlatform(Platform):
+     _enum = PlatformEnum.TPU
+     device_name: str = "tpu"
+     device_type: str = "tpu"
+     dispatch_key: str = "XLA"
+     ray_device_key: str = "TPU"
+     device_control_env_var: str = "TPU_VISIBLE_CHIPS"
+     simple_compile_backend: str = "openxla"
+
+     supported_quantization: list[str] = [
+         "tpu_int8", "compressed-tensors", "awq", "fp8", "mxfp4"
+     ]
+
+     additional_env_vars: list[str] = [
+         "PHASED_PROFILING_DIR", "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS",
+         "TPU_MULTIHOST_BACKEND", "VLLM_MLA_DISABLE", "TPU_BACKEND_TYPE"
+     ]
+
+     @classmethod
+     def get_attn_backend_cls(cls, selected_backend: "AttentionBackendEnum",
+                              head_size: int, dtype: jnp.dtype,
+                              kv_cache_dtype: Optional[str], block_size: int,
+                              use_v1: bool, use_mla: bool, has_sink: bool,
+                              use_sparse: bool, attn_type: Any) -> str:
+         from vllm.attention.backends.registry import AttentionBackendEnum
+         if selected_backend != AttentionBackendEnum.PALLAS:
+             logger.info("Cannot use %s backend on TPU.", selected_backend)
+
+         if use_v1:
+             logger.info("Using Pallas V1 backend.")
+             return "tpu_inference.layers.vllm.attention.PallasAttentionBackend"
+         else:
+             logger.info("Using Pallas backend.")
+             return "vllm.attention.backends.pallas.PallasAttentionBackend"
+
+     @classmethod
+     def get_device_name(cls, device_id: int = 0) -> str:
+         try:
+             if vllm_envs.VLLM_TPU_USING_PATHWAYS:
+                 # Calling jax.devices() under Pathways causes multiprocess
+                 # access to IFRT, so return a fixed name instead.
+                 return "TPU v6 lite"
+             else:
+                 chip_type, _ = device.get_local_chips()
+                 return f"TPU {chip_type.name}"
+         except Exception as e:
+             logger.warning(f"Error getting device name: {e}")
+             return 'TPU'
+
+     @classmethod
+     def fp8_dtype(cls) -> torch.dtype:
+         if cls.get_device_name().lower() == "tpu v6e":
+             logger.info(
+                 "Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.")
+             return torch.float8_e5m2
+         return torch.float8_e4m3fn
+
+     @classmethod
+     def get_device_total_memory(cls, device_id: int = 0) -> int:
+         raise NotImplementedError
+
+     @classmethod
+     def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
+         return False
+
+     @classmethod
+     def get_punica_wrapper(cls) -> str:
+         return "tpu_inference.lora.torch_punica_tpu.PunicaWrapperTPU"
+
+     @classmethod
+     def get_infinity_values(cls, dtype: jnp.dtype) -> Tuple[float, float]:
+         return jnp.finfo(dtype).min, jnp.finfo(dtype).max
+
+     @classmethod
+     def can_update_inplace(cls):
+         return False
+
+     @classmethod
+     def get_lora_vocab_padding_size(cls) -> int:
+         return 1
+
+     @classmethod
+     def inference_mode(cls):
+         return True
+
+     @classmethod
+     def _initialize_sharding_config(cls, vllm_config: VllmConfig) -> None:
+
+         sharding_config = ShardingConfigManager.from_vllm_config(vllm_config)
+         vllm_config.sharding_config = sharding_config
+         logger.info(f"Initialized sharding configuration: {sharding_config}")
+
+     @classmethod
+     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
+
+         if vllm_envs.VLLM_TPU_USING_PATHWAYS:
+             assert not vllm_envs.VLLM_ENABLE_V1_MULTIPROCESSING, (
+                 "VLLM_ENABLE_V1_MULTIPROCESSING must be 0 when using Pathways (JAX_PLATFORMS=proxy)"
+             )
+         cls._initialize_sharding_config(vllm_config)
+
+         from vllm.config import CompilationMode
+
+         cache_config = vllm_config.cache_config
+         # For v0, the default block size is 16.
+         if cache_config and cache_config.block_size is None:
+             cache_config.block_size = cast(BlockSize, 16)
+
+         compilation_config = vllm_config.compilation_config
+
+         # TPU only supports the DYNAMO_TRACE_ONCE compilation mode.
+         # NOTE(xiang): the compilation_config is not used by jax.
+         if compilation_config.mode != CompilationMode.DYNAMO_TRACE_ONCE:
+             compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE
+
+         if compilation_config.backend == "":
+             compilation_config.backend = "openxla"
+
+         # If we use vLLM's model implementation in PyTorch, we should set it
+         # with the torch version of the dtype.
+         impl = envs.MODEL_IMPL_TYPE
+
+         # NOTE(xiang): convert dtype to jnp.dtype
+         # NOTE(wenlong): skip this logic for mm model preprocessing
+         # For mm model preprocessors, it may need the output dtype to be torch.
+         # In order to avoid a PR to vLLM, we postpone the dtype checking during
+         # tpu_worker initialization.
+         if not vllm_config.scheduler_config.is_multimodal_model or impl == "vllm":
+             model_dtype = vllm_config.model_config.dtype
+             try:
+                 dtype = to_jax_dtype(model_dtype)
+             except ValueError:
+                 logger.warning(f"{model_dtype=} is not supported. "
+                                "Falling back to jnp.bfloat16")
+                 dtype = jnp.bfloat16
+             if impl == "vllm":
+                 dtype = to_torch_dtype(dtype)
+             vllm_config.model_config.dtype = dtype
+
+         # TODO(cuiq): remove this dependency.
+         from vllm.v1.attention.backends.pallas import PallasAttentionBackend
+         cache_config.block_size = PallasAttentionBackend.get_page_size(
+             vllm_config)  # type: ignore[assignment]
+         min_page_size = PallasAttentionBackend.get_min_page_size(vllm_config)
+         if min_page_size > cache_config.block_size:
+             logger.warning(
+                 "Increase the page size from %s to %s to make sure there's "
+                 "no SMEM OOM",
+                 cache_config.block_size,
+                 min_page_size,
+             )
+             cache_config.block_size = min_page_size  # type: ignore[assignment]
+
+         parallel_config = vllm_config.parallel_config
+         scheduler_config = vllm_config.scheduler_config
+         parallel_config.worker_cls = \
+             "tpu_inference.worker.tpu_worker.TPUWorker"
+
+         multihost_backend = envs.TPU_MULTIHOST_BACKEND
+         if not multihost_backend:  # Single host
+             if parallel_config.pipeline_parallel_size == 1:
+                 logger.info("Forcing UniProcExecutor for JAX on a single "
+                             "host without pipeline parallelism.")
+                 parallel_config.distributed_executor_backend = "uni"
+             else:
+                 logger.info("Forcing MultiprocExecutor for JAX on a single "
+                             "host with pipeline parallelism.")
+                 parallel_config.distributed_executor_backend = "mp"
+         elif multihost_backend == "ray":
+             from tpu_inference.executors.ray_distributed_executor import \
+                 RayDistributedExecutor
+             parallel_config.distributed_executor_backend = RayDistributedExecutor
+             logger.info(
+                 "Forcing RayDistributedExecutor for JAX on multihost.")
+         else:
+             logger.warning(
+                 f"Unknown TPU multihost backend: {multihost_backend}. "
+                 "Using uniproc_executor.")
+             parallel_config.distributed_executor_backend = "uni"
+
+         if scheduler_config.is_multimodal_model and not \
+             scheduler_config.disable_chunked_mm_input:
+             logger.warning("TPU does not support running multimodal models "
+                            "without setting `--disable_chunked_mm_input`. "
+                            "Forcing --disable_chunked_mm_input.")
+             scheduler_config.disable_chunked_mm_input = True
+
+         kv_transfer_config = vllm_config.kv_transfer_config
+         if kv_transfer_config is not None:
+             assert kv_transfer_config.kv_connector == "TPUConnector"
+         # Late initialization to avoid circular import
+         from tpu_inference.models.jax.utils.quantization.quantization_utils import \
+             update_vllm_config_for_qwix_quantization
+
+         update_vllm_config_for_qwix_quantization(vllm_config)
+
+         from tpu_inference.core.sched.dp_scheduler import \
+             update_vllm_config_for_dp_scheduler
+         update_vllm_config_for_dp_scheduler(vllm_config)
+
+     @classmethod
+     def is_pin_memory_available(cls):
+         logger.warning("Pin memory is not supported on TPU.")
+         return False
+
+     @classmethod
+     def get_device_communicator_cls(cls) -> str:
+         return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator"  # noqa
+
+     @classmethod
+     def use_all_gather(cls) -> bool:
+         return True
+
+     @classmethod
+     def supports_v1(cls, model_config: ModelConfig) -> bool:
+         # V1 support on TPU is experimental.
+         return True
+
+     @classmethod
+     def validate_request(
+         cls,
+         prompt: PromptType,
+         params: Union[SamplingParams, PoolingParams],
+         processed_inputs: ProcessorInputs,
+     ) -> None:
+         """Raises if this request is unsupported on this platform."""
+
+         if isinstance(params, SamplingParams):
+             if params.sampling_type == SamplingType.RANDOM_SEED:
+                 raise ValueError("JAX does not support per-request seed.")
+
+     @classmethod
+     def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
+                                     model_config: ModelConfig) -> bool:
+         return True
+
+     @classmethod
+     def use_sync_weight_loader(cls) -> bool:
+         """
+         Returns whether the current platform needs the sync weight loader.
+         """
+         return True
+
+     @classmethod
+     def support_hybrid_kv_cache(cls) -> bool:
+         return True
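
One behavior of TpuPlatform worth highlighting is request validation: seeded sampling is rejected because the JAX sampling path here has no per-request seed support. Below is a hedged illustration of what a caller would observe; setting a seed together with a non-zero temperature is what selects SamplingType.RANDOM_SEED in vLLM, and since this implementation does not inspect prompt or processed_inputs, placeholders stand in for them.

from vllm.sampling_params import SamplingParams

from tpu_inference.platforms import TpuPlatform

params = SamplingParams(temperature=0.8, seed=1234)
try:
    TpuPlatform.validate_request(prompt="hello",
                                 params=params,
                                 processed_inputs=None)
except ValueError as err:
    print(err)  # JAX does not support per-request seed.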