tpu_inference-0.11.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (168)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_adapters.py +83 -0
  4. tests/core/test_core_tpu.py +523 -0
  5. tests/core/test_disagg_executor.py +60 -0
  6. tests/core/test_disagg_utils.py +53 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  10. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  11. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  12. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  13. tests/lora/__init__.py +0 -0
  14. tests/lora/test_lora.py +123 -0
  15. tests/test_base.py +201 -0
  16. tests/test_quantization.py +836 -0
  17. tests/test_tpu_info.py +120 -0
  18. tests/test_utils.py +218 -0
  19. tests/tpu_backend_test.py +59 -0
  20. tpu_inference/__init__.py +30 -0
  21. tpu_inference/adapters/__init__.py +0 -0
  22. tpu_inference/adapters/vllm_adapters.py +42 -0
  23. tpu_inference/adapters/vllm_config_adapters.py +134 -0
  24. tpu_inference/backend.py +69 -0
  25. tpu_inference/core/__init__.py +0 -0
  26. tpu_inference/core/adapters.py +153 -0
  27. tpu_inference/core/core_tpu.py +776 -0
  28. tpu_inference/core/disagg_executor.py +117 -0
  29. tpu_inference/core/disagg_utils.py +51 -0
  30. tpu_inference/di/__init__.py +0 -0
  31. tpu_inference/di/abstracts.py +28 -0
  32. tpu_inference/di/host.py +76 -0
  33. tpu_inference/di/interfaces.py +51 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/tpu_connector.py +699 -0
  36. tpu_inference/distributed/utils.py +59 -0
  37. tpu_inference/executors/__init__.py +0 -0
  38. tpu_inference/executors/ray_distributed_executor.py +346 -0
  39. tpu_inference/experimental/__init__.py +0 -0
  40. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  41. tpu_inference/interfaces/__init__.py +0 -0
  42. tpu_inference/interfaces/cache.py +31 -0
  43. tpu_inference/interfaces/config.py +47 -0
  44. tpu_inference/interfaces/config_parts.py +117 -0
  45. tpu_inference/interfaces/engine.py +51 -0
  46. tpu_inference/interfaces/outputs.py +22 -0
  47. tpu_inference/interfaces/params.py +21 -0
  48. tpu_inference/interfaces/platform.py +74 -0
  49. tpu_inference/interfaces/request.py +39 -0
  50. tpu_inference/interfaces/scheduler.py +31 -0
  51. tpu_inference/kernels/__init__.py +0 -0
  52. tpu_inference/kernels/collectives/__init__.py +0 -0
  53. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  54. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  55. tpu_inference/kernels/collectives/util.py +47 -0
  56. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  57. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  58. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  59. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  60. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  61. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  62. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  66. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
  71. tpu_inference/layers/__init__.py +0 -0
  72. tpu_inference/layers/common/__init__.py +0 -0
  73. tpu_inference/layers/common/attention_metadata.py +34 -0
  74. tpu_inference/layers/jax/__init__.py +0 -0
  75. tpu_inference/layers/jax/attention/__init__.py +0 -0
  76. tpu_inference/layers/jax/attention/attention.py +254 -0
  77. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  78. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  79. tpu_inference/layers/jax/attention_interface.py +356 -0
  80. tpu_inference/layers/jax/base.py +151 -0
  81. tpu_inference/layers/jax/binary_search.py +295 -0
  82. tpu_inference/layers/jax/constants.py +88 -0
  83. tpu_inference/layers/jax/layers.py +301 -0
  84. tpu_inference/layers/jax/misc.py +16 -0
  85. tpu_inference/layers/jax/moe/__init__.py +0 -0
  86. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  87. tpu_inference/layers/jax/moe/moe.py +209 -0
  88. tpu_inference/layers/jax/rope.py +172 -0
  89. tpu_inference/layers/jax/rope_interface.py +214 -0
  90. tpu_inference/layers/jax/sample/__init__.py +0 -0
  91. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  92. tpu_inference/layers/jax/sample/sampling.py +95 -0
  93. tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
  94. tpu_inference/layers/jax/sharding.py +406 -0
  95. tpu_inference/layers/jax/transformer_block.py +76 -0
  96. tpu_inference/layers/vllm/__init__.py +0 -0
  97. tpu_inference/layers/vllm/attention.py +184 -0
  98. tpu_inference/layers/vllm/fused_moe.py +399 -0
  99. tpu_inference/layers/vllm/linear_common.py +186 -0
  100. tpu_inference/layers/vllm/quantization/__init__.py +34 -0
  101. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  102. tpu_inference/layers/vllm/quantization/common.py +105 -0
  103. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  104. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
  105. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  106. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  108. tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
  109. tpu_inference/layers/vllm/sharding.py +151 -0
  110. tpu_inference/logger.py +10 -0
  111. tpu_inference/lora/__init__.py +0 -0
  112. tpu_inference/lora/torch_lora_ops.py +103 -0
  113. tpu_inference/lora/torch_punica_tpu.py +308 -0
  114. tpu_inference/mock/__init__.py +0 -0
  115. tpu_inference/mock/vllm_config_utils.py +28 -0
  116. tpu_inference/mock/vllm_envs.py +1233 -0
  117. tpu_inference/mock/vllm_logger.py +212 -0
  118. tpu_inference/mock/vllm_logging_utils.py +15 -0
  119. tpu_inference/models/__init__.py +0 -0
  120. tpu_inference/models/common/__init__.py +0 -0
  121. tpu_inference/models/common/model_loader.py +433 -0
  122. tpu_inference/models/jax/__init__.py +0 -0
  123. tpu_inference/models/jax/deepseek_v3.py +868 -0
  124. tpu_inference/models/jax/llama3.py +366 -0
  125. tpu_inference/models/jax/llama4.py +473 -0
  126. tpu_inference/models/jax/llama_eagle3.py +333 -0
  127. tpu_inference/models/jax/phi3.py +376 -0
  128. tpu_inference/models/jax/qwen2.py +375 -0
  129. tpu_inference/models/jax/qwen2_5_vl.py +976 -0
  130. tpu_inference/models/jax/qwen3.py +302 -0
  131. tpu_inference/models/jax/utils/__init__.py +0 -0
  132. tpu_inference/models/jax/utils/file_utils.py +96 -0
  133. tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
  134. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
  136. tpu_inference/models/jax/utils/weight_utils.py +510 -0
  137. tpu_inference/models/vllm/__init__.py +0 -0
  138. tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
  139. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  140. tpu_inference/platforms/__init__.py +2 -0
  141. tpu_inference/platforms/tpu_jax.py +257 -0
  142. tpu_inference/runner/__init__.py +0 -0
  143. tpu_inference/runner/block_table_jax.py +122 -0
  144. tpu_inference/runner/compilation_manager.py +672 -0
  145. tpu_inference/runner/input_batch_jax.py +435 -0
  146. tpu_inference/runner/kv_cache.py +119 -0
  147. tpu_inference/runner/kv_cache_manager.py +460 -0
  148. tpu_inference/runner/lora_utils.py +92 -0
  149. tpu_inference/runner/multimodal_manager.py +208 -0
  150. tpu_inference/runner/persistent_batch_manager.py +244 -0
  151. tpu_inference/runner/speculative_decoding_manager.py +250 -0
  152. tpu_inference/runner/structured_decoding_manager.py +89 -0
  153. tpu_inference/runner/tpu_jax_runner.py +771 -0
  154. tpu_inference/runner/utils.py +426 -0
  155. tpu_inference/spec_decode/__init__.py +0 -0
  156. tpu_inference/spec_decode/jax/__init__.py +0 -0
  157. tpu_inference/spec_decode/jax/eagle3.py +334 -0
  158. tpu_inference/tpu_info.py +77 -0
  159. tpu_inference/utils.py +294 -0
  160. tpu_inference/worker/__init__.py +0 -0
  161. tpu_inference/worker/_temporary_vllm_compat.py +129 -0
  162. tpu_inference/worker/base.py +100 -0
  163. tpu_inference/worker/tpu_worker_jax.py +321 -0
  164. tpu_inference-0.11.1.dist-info/METADATA +101 -0
  165. tpu_inference-0.11.1.dist-info/RECORD +168 -0
  166. tpu_inference-0.11.1.dist-info/WHEEL +5 -0
  167. tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
  168. tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
tpu_inference/worker/tpu_worker_jax.py
@@ -0,0 +1,321 @@
+ # SPDX-License-Identifier: Apache-2.0
+
+ import os
+ import tempfile
+ from typing import Callable, Dict, Optional, Tuple, Union
+
+ import jax
+ import jax.numpy as jnp
+ import jaxtyping
+ import vllm.envs as envs
+ from vllm.config import VllmConfig, set_current_vllm_config
+ from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
+                                           has_kv_transfer_group)
+ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+                                              init_distributed_environment)
+ from vllm.lora.request import LoRARequest
+ from vllm.tasks import SupportedTask
+ from vllm.v1.core.kv_cache_utils import get_num_blocks, get_uniform_page_size
+ from vllm.v1.core.sched.output import SchedulerOutput
+ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
+ from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
+
+ from tpu_inference import utils
+ from tpu_inference.di.abstracts import (AbstractKVCacheConfig,
+                                         AbstractLoRARequest,
+                                         AbstractSchedulerOutput)
+ from tpu_inference.di.interfaces import HostInterface
+ from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port,
+                                              get_node_id)
+ from tpu_inference.logger import init_logger
+ from tpu_inference.runner.kv_cache import get_rpa_page_size_bytes
+ from tpu_inference.runner.tpu_jax_runner import TPUModelRunner
+ from tpu_inference.worker._temporary_vllm_compat import (
+     adapt_kv_cache_config_if_needed, adapt_lora_request_if_needed,
+     adapt_scheduler_output_if_needed)
+ from tpu_inference.worker.base import AbstractTpuWorker
+
+ logger = init_logger(__name__)
+
+ _DTYPE: dict[str, jnp.dtype] = {
+     "bfloat16": jnp.bfloat16,
+     "float": jnp.float32,
+     "float32": jnp.float32,
+ }
+
+
+ class TPUWorker(AbstractTpuWorker):
+
+     def __init__(self,
+                  vllm_config: VllmConfig,
+                  local_rank: int,
+                  rank: int,
+                  distributed_init_method: str,
+                  is_driver_worker: bool = False,
+                  devices=None,
+                  host_interface: Optional[HostInterface] = None):
+         super().__init__(host_interface)
+
+         # If we use vLLM's model implementation in PyTorch, we should set it
+         # with torch version of the dtype.
+         impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
+         if impl != "vllm":  # vllm-pytorch implementation does not need this conversion
+
+             # NOTE(wenlong): because sometimes mm needs to use torch for preprocessing
+             if not isinstance(vllm_config.model_config.dtype, str):
+                 logger.warning(
+                     "The model dtype is not properly set for JAX backend. "
+                     "Overwriting it to jnp.bfloat16")
+                 vllm_config.model_config.dtype = jnp.bfloat16
+             else:
+                 vllm_config.model_config.dtype = _DTYPE.get(
+                     vllm_config.model_config.dtype, jnp.bfloat16)
+
+         self.vllm_config = vllm_config
+         self.model_config = vllm_config.model_config
+         self.parallel_config = vllm_config.parallel_config
+         self.cache_config = vllm_config.cache_config
+         self.local_rank = local_rank
+         self.rank = rank
+         self.distributed_init_method = distributed_init_method
+         self.is_driver_worker = is_driver_worker
+         self.devices = devices if devices is not None else []
+
+         if self.model_config.trust_remote_code:
+             # note: lazy import to avoid importing torch before initializing
+             from vllm.utils import init_cached_hf_modules
+
+             init_cached_hf_modules()
+
+         # Delay profiler initialization to the start of the profiling.
+         # This is because in vLLM V1, MP runtime is initialized before the
+         # TPU Worker is initialized. The profiler server needs to start after
+         # MP runtime is initialized.
+         self.profile_dir = None
+         if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
+             # For TPU, we can only have 1 active profiler session for 1 profiler
+             # server. So we only profile on rank0.
+             self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR
+             logger.info("Profiling enabled. Traces will be saved to: %s",
+                         self.profile_dir)
+
+         use_jax_profiler_server = os.getenv("USE_JAX_PROFILER_SERVER", False)
+         # Only one instance of profiler is allowed
+         if use_jax_profiler_server and jax.devices()[0] == self.devices[0]:
+             jax_profiler_server_port = int(
+                 os.getenv("JAX_PROFILER_SERVER_PORT", 9999))
+             logger.info(
+                 f"Starting JAX profiler server on port {jax_profiler_server_port}"
+             )
+             jax.profiler.start_server(jax_profiler_server_port)
+
+     def initialize_cache(self, num_gpu_blocks: int,
+                          num_cpu_blocks: int) -> None:
+         self.cache_config.num_gpu_blocks = num_gpu_blocks
+         self.cache_config.num_cpu_blocks = num_cpu_blocks
+
+     def init_device(self):
+         if not self.devices:
+             try:
+                 device_indexes = self.vllm_config.additional_config[
+                     "sharding"]["sharding_strategy"]["device_indexes"]
+                 self.devices = [jax.devices()[i] for i in device_indexes]
+             except KeyError:
+                 tp = self.parallel_config.tensor_parallel_size
+                 self.devices = jax.devices()[:tp]
+
+         # Initialize the vLLM distribution layer as a single chip environment,
+         # we'll swap the model's parallel modules with TPU SPMD equivalents.
+         with set_current_vllm_config(self.vllm_config):
+             temp_file = tempfile.mkstemp()[1]
+             init_distributed_environment(
+                 world_size=1,
+                 rank=0,
+                 local_rank=0,
+                 distributed_init_method=f"file://{temp_file}",
+                 backend="gloo",
+             )
+             ensure_model_parallel_initialized(
+                 tensor_model_parallel_size=1,
+                 pipeline_model_parallel_size=1,
+             )
+         ensure_kv_transfer_initialized(self.vllm_config)
+         self.model_runner = TPUModelRunner(self.vllm_config, self.devices)
+         logger.info(f"Init worker | "
+                     f"rank={self.rank} | "
+                     f"node_id={get_node_id()} | "
+                     f"is_driver_worker={self.is_driver_worker} | "
+                     f"hbm={utils.hbm_usage_gb(self.devices)}GiB")
+
+     def determine_available_memory(self) -> int:
+         gpu_memory_utilization = self.cache_config.gpu_memory_utilization
+         hbm_usage = utils.hbm_usage_bytes(self.devices)
+         total_hbm_limit = total_hbm_used = 0
+         for used, limit in hbm_usage:
+             total_hbm_used += used
+             total_hbm_limit += limit
+
+         total_hbm_limit_cap = total_hbm_limit * gpu_memory_utilization
+         total_hbm_avail = int(total_hbm_limit_cap - total_hbm_used)
+
+         total_hbm_limit_gb = round(total_hbm_limit / utils.GBYTES, 2)
+         total_hbm_limit_cap_gb = round(total_hbm_limit_cap / utils.GBYTES, 2)
+         total_hbm_used_gb = round(total_hbm_used / utils.GBYTES, 2)
+         total_hbm_avail_gb = round(total_hbm_avail / utils.GBYTES, 2)
+
+         logger.info(f"Memory statistics | "
+                     f"{total_hbm_limit_gb=}GiB | "
+                     f"{total_hbm_limit_cap_gb=}GiB | "
+                     f"{total_hbm_used_gb=}GiB | "
+                     f"{total_hbm_avail_gb=}GiB")
+
+         if total_hbm_avail <= 0:
+             raise ValueError(f"{total_hbm_used_gb=}GiB exceeds "
+                              f"{total_hbm_limit_cap_gb=}GiB by "
+                              f"{-total_hbm_avail_gb}GiB. Please consider "
+                              f"increasing --gpu-memory-utilization from "
+                              f"{gpu_memory_utilization} to a larger value.")
+         return total_hbm_avail
+
+     def execute_model(
+         self,
+         scheduler_output: Union[AbstractSchedulerOutput, SchedulerOutput],
+     ) -> Optional[ModelRunnerOutput]:
+         # NOTE: This method intentionally returns a concrete vLLM type, which
+         # violates the pure abstract contract of the base class. This is a
+         # deliberate, temporary compromise for the same reasons outlined in
+         # the `get_kv_cache_spec` method.
+
+         # Adapt the input if necessary (temporary compatibility layer)
+         adapted_scheduler_output = adapt_scheduler_output_if_needed(
+             scheduler_output)
+
+         # Unwrap the adapter to get the concrete vLLM object
+         vllm_scheduler_output = adapted_scheduler_output.vllm_scheduler_output
+         output = self.model_runner.execute_model(vllm_scheduler_output)
+
+         # With a connector, the scheduler expects output from all workers
+         if has_kv_transfer_group():
+             return output
+
+         return output if self.is_driver_worker else None
+
+     def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
+         return self.model_runner.take_draft_token_ids()
+
+     def add_lora(
+         self,
+         lora_request: Union[AbstractLoRARequest, LoRARequest],
+     ) -> bool:
+         # Adapt the input if necessary (temporary compatibility layer)
+         adapted_lora_request = adapt_lora_request_if_needed(lora_request)
+
+         # Unwrap the adapter to get the concrete vLLM object
+         vllm_lora_request = adapted_lora_request.vllm_lora_request  # noqa: F841
+
+         raise NotImplementedError(
+             "LoRA is not supported by the JAX worker yet.")
+
+     def profile(self, is_start: bool = True):
+         if is_start:
+             options = jax.profiler.ProfileOptions()
+             options.python_tracer_level = os.getenv("PYTHON_TRACER_LEVEL", 0)
+             jax.profiler.start_trace(self.profile_dir,
+                                      profiler_options=options)
+         else:
+             jax.profiler.stop_trace()
+
+     def load_model(self) -> None:
+         self.model_runner.load_model()
+
+     def compile_or_warm_up_model(self) -> None:
+         self.model_runner.capture_model()
+         # Reset the seed to ensure that the random state is not affected by
+         # the model initialization and profiling.
+         self.model_runner._init_random()
+
+     def reset_mm_cache(self) -> None:
+         pass
+
+     def get_model(self):
+         return self.model_runner.get_model()
+
+     def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+         return self.model_runner.get_supported_tasks()
+
+     def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
+         # NOTE: This method intentionally returns a concrete vLLM type, which
+         # violates the pure abstract contract of the base class. This is a
+         # deliberate, temporary compromise.
+         #
+         # The vLLM executor that calls this method expects the concrete
+         # `vllm.KVCacheSpec` object to perform its own internal logic. If we
+         # returned an abstract adapter, the vLLM code would break.
+         #
+         # The ideal long-term solution is for the vLLM DI container to be
+         # responsible for this translation. When vLLM can be modified, this
+         # method should be changed to return `dict[str, AbstractKVCacheSpec]`,
+         # and the vLLM side should be updated to handle the translation.
+         kv_cache_specs = self.model_runner.get_kv_cache_spec()
+
+         if len(kv_cache_specs) == 0:
+             return kv_cache_specs
+
+         # TODO(kyuyeunk): Instead of checking page_size_bytes here, introduce
+         # feature that allows overriding page_size_bytes of KVCacheSpec.
+         vllm_page_size_bytes = get_uniform_page_size(kv_cache_specs)
+         rpa_page_size_bytes = get_rpa_page_size_bytes(self.model_runner.mesh,
+                                                       kv_cache_specs)
+
+         if vllm_page_size_bytes != rpa_page_size_bytes:
+             logger.info(
+                 f"KV cache page size calculated by vLLM "
+                 f"({vllm_page_size_bytes} Bytes) does not match with actual "
+                 f"page size used by RPA kernel ({rpa_page_size_bytes} Bytes). "
+                 f"Recalculating number of KV blocks using actual page size.")
+
+             available_memory = self.determine_available_memory()
+             num_blocks = get_num_blocks(self.vllm_config, len(kv_cache_specs),
+                                         available_memory, rpa_page_size_bytes)
+
+             cache_config = self.vllm_config.cache_config
+             cache_config.num_gpu_blocks_override = num_blocks
+
+         return kv_cache_specs
+
+     def initialize_from_config(
+         self,
+         kv_cache_config: Union[AbstractKVCacheConfig, KVCacheConfig],
+     ) -> None:
+         """Allocate GPU KV cache with the specified kv_cache_config."""
+         adapted_kv_cache_config = adapt_kv_cache_config_if_needed(
+             kv_cache_config)
+         vllm_kv_cache_config = adapted_kv_cache_config.vllm_kv_cache_config
+         self.model_runner.initialize_kv_cache(vllm_kv_cache_config)
+
+     def get_node_kv_ip_port(self) -> tuple[int, str, int]:
+         node_id = get_node_id()
+         ip = get_host_ip()
+         port = get_kv_transfer_port()
+         return (int(node_id), ip, int(port))
+
+     def check_health(self) -> None:
+         # worker will always be healthy as long as it's running.
+         return
+
+     def sync_weights(
+         self,
+         updated_weights: jaxtyping.PyTree,
+         mappings: Dict[str, Tuple[str, Tuple[str]]],
+         transpose_keys: Dict[str, Tuple[int]],
+         reshard_fn: Callable[[jaxtyping.PyTree, jaxtyping.PyTree],
+                              jaxtyping.PyTree] = None
+     ) -> None:
+         """Sync the updated weights to the model runner."""
+         return self.model_runner._sync_weights(updated_weights=updated_weights,
+                                                mappings=mappings,
+                                                transpose_keys=transpose_keys,
+                                                reshard_fn=reshard_fn)
+
+     def shutdown(self) -> None:
+         return
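
The memory accounting in `determine_available_memory` above reduces to a few lines of arithmetic: sum the per-device (used, limit) HBM byte counts, cap the limit by `--gpu-memory-utilization`, and subtract what is already allocated. A minimal standalone sketch of the same computation (the per-device numbers are made up for illustration; in the worker they come from `utils.hbm_usage_bytes`):

```python
GBYTES = 1024**3  # GiB conversion constant, mirroring utils.GBYTES above

def available_hbm(hbm_usage, gpu_memory_utilization=0.9):
    total_used = sum(used for used, _ in hbm_usage)
    total_limit = sum(limit for _, limit in hbm_usage)
    # Cap usable HBM at the requested utilization fraction, then subtract
    # bytes already held by weights, compiled programs, etc.
    cap = total_limit * gpu_memory_utilization
    avail = int(cap - total_used)
    if avail <= 0:
        raise ValueError(f"used {total_used / GBYTES:.2f}GiB exceeds the "
                         f"capped limit of {cap / GBYTES:.2f}GiB")
    return avail

# Example: four chips with 32 GiB HBM each, 8 GiB already used per chip.
usage = [(8 * GBYTES, 32 * GBYTES)] * 4
print(f"{available_hbm(usage) / GBYTES:.2f} GiB available for KV cache")
```
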
tpu_inference-0.11.1.dist-info/METADATA
@@ -0,0 +1,101 @@
+ Metadata-Version: 2.4
+ Name: tpu_inference
+ Version: 0.11.1
+ Author: tpu_inference Contributors
+ Classifier: Development Status :: 3 - Alpha
+ Classifier: Intended Audience :: Developers
+ Classifier: Intended Audience :: Education
+ Classifier: Intended Audience :: Science/Research
+ Classifier: License :: OSI Approved :: Apache Software License
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: tpu-info==0.4.0
+ Requires-Dist: yapf==0.43.0
+ Requires-Dist: pytest
+ Requires-Dist: pytest-mock
+ Requires-Dist: absl-py
+ Requires-Dist: numpy
+ Requires-Dist: google-cloud-storage
+ Requires-Dist: jax==0.7.2
+ Requires-Dist: jaxlib==0.7.2
+ Requires-Dist: libtpu==0.0.23
+ Requires-Dist: jaxtyping
+ Requires-Dist: flax==0.11.1
+ Requires-Dist: torchax==0.0.7
+ Requires-Dist: qwix==0.1.1
+ Requires-Dist: torchvision==0.23.0
+ Requires-Dist: pathwaysutils
+ Requires-Dist: parameterized
+ Dynamic: author
+ Dynamic: classifier
+ Dynamic: description
+ Dynamic: description-content-type
+ Dynamic: license-file
+ Dynamic: requires-dist
+ Dynamic: requires-python
+
+ <p align="center">
+ <!-- This image will ONLY show up in GitHub's dark mode -->
+ <img src="docs/assets/tpu_inference_dark_mode_short.png#gh-dark-mode-only" alt="vLLM TPU" style="width: 86%;">
+ <!-- This image will ONLY show up in GitHub's light mode (and on other platforms) -->
+ <img src="docs/assets/tpu_inference_light_mode_short.png#gh-light-mode-only" alt="vLLM TPU" style="width: 86%;">
+ </p>
+
+ <p align="center">
+ | <a href="https://tpu.vllm.ai"><b>Documentation</b></a> | <a href="https://blog.vllm.ai/"><b>Blog</b></a> | <a href="https://discuss.vllm.ai/c/hardware-support/google-tpu-support/27"><b>User Forum</b></a> | <a href="https://join.slack.com/share/enQtOTY2OTUxMDIyNjY1OS00M2MxYWQwZjAyMGZjM2MyZjRjNTA0ZjRkNjkzOTRhMzg0NDM2OTlkZDAxOTAzYmJmNzdkNDc4OGZjYTUwMmRh"><b>Developer Slack</b></a> |
+ </p>
+
+ ---
+
+ _Upcoming Events_ 🔥
+
+ - Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundation.org/pytorch-conference/) in San Francisco!
+ - Join us at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
+ - Join us at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
+
+ _Latest News_ 🔥
+
+ - [2025/10] vLLM TPU: A New Unified Backend Supporting PyTorch and JAX on TPU
+
+ <details>
+ <summary><i>Previous News</i> 🔥</summary>
+
+ </details>
+
+ ---
+ ## About
+
+ vLLM TPU is now powered by `tpu-inference`, an expressive and powerful new hardware plugin unifying JAX and PyTorch under a single lowering path within the vLLM project. The new backend provides a framework for developers to:
+
+ - Push the limits of TPU hardware performance in open source.
+ - Provide more flexibility to JAX and PyTorch users by running PyTorch model definitions performantly on TPU without any additional code changes, while also extending native support to JAX.
+ - Retain vLLM standardization: keep the same user experience, telemetry, and interface.
+
+ ## Recommended models and features
+
+ Although vLLM TPU's new unified backend makes out-of-the-box, high-performance serving possible with any model supported in vLLM, we are still in the process of implementing a few core components.
+
+ For this reason, we've provided a list of recommended [models](https://github.com/vllm-project/tpu-inference/blob/main/support_matrices/model_support_matrix.csv) and [features](https://github.com/vllm-project/tpu-inference/blob/main/support_matrices/feature_support_matrix.csv) that are validated for accuracy and stress-tested for performance.
+
+ ## Get started
+
+ Get started with vLLM on TPUs by following the [quickstart guide](https://github.com/vllm-project/tpu-inference/tree/main/docs/getting_started/quickstart.md).
+
+ Visit our [documentation](https://github.com/vllm-project/tpu-inference/tree/main/docs) to learn more.
+
+ ## Contribute
+
+ We're always looking for ways to partner with the community to accelerate vLLM TPU development. If you're interested in contributing, check out the [Contributing guide](https://github.com/vllm-project/tpu-inference/blob/main/CONTRIBUTING.md) and [Issues](https://github.com/vllm-project/tpu-inference/issues) to start. If it's your first time contributing, we recommend filtering Issues on the [**good first issue** tag](https://github.com/vllm-project/tpu-inference/issues?q=is%3Aissue+state%3Aopen+label%3A%22good+first+issue%22).
+
+ ## Contact us
+
+ - For technical questions, open a GitHub [Issue](https://github.com/vllm-project/tpu-inference/issues)
+ - For feature requests, open one on GitHub [here](https://github.com/vllm-project/tpu-inference/issues/new/choose)
+ - For discussions with fellow users, use the [TPU support topic in the vLLM Forum](https://discuss.vllm.ai/c/hardware-support/google-tpu-support/27)
+ - For coordinating contributions and development, use the [Developer Slack](https://join.slack.com/share/enQtOTY2OTUxMDIyNjY1OS00M2MxYWQwZjAyMGZjM2MyZjRjNTA0ZjRkNjkzOTRhMzg0NDM2OTlkZDAxOTAzYmJmNzdkNDc4OGZjYTUwMmRh)
+ - For collaborations and partnerships, contact us at [vllm-tpu@google.com](mailto:vllm-tpu@google.com)
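
The quickstart linked above covers TPU-specific installation; once the plugin is installed, serving goes through the standard vLLM entry points. A minimal offline-inference sketch using vLLM's stable Python API (the model name and sampling settings here are illustrative, not taken from this package):

```python
from vllm import LLM, SamplingParams

# vLLM discovers the TPU backend through its platform-plugin mechanism,
# so no TPU-specific code is needed here.
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct", max_model_len=1024)
params = SamplingParams(temperature=0.7, max_tokens=64)

for out in llm.generate(["What makes TPUs fast for inference?"], params):
    print(out.outputs[0].text)
```
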
tpu_inference-0.11.1.dist-info/RECORD
@@ -0,0 +1,168 @@
+ tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tests/test_base.py,sha256=Ct5WFRMHL7IHEIxk8FrzAvO8m0xFuDpzDBKkAKKAL2Q,7341
+ tests/test_quantization.py,sha256=tmHBwpAh1Lz4cSB15fwnvmbA1TZ_zM_I1iP99hhGaEk,34444
+ tests/test_tpu_info.py,sha256=ZrwlMsp8ffITkS_b8Q1t_QG-a-WVAd4NUcjHhGibcsI,4670
+ tests/test_utils.py,sha256=JFxlYnIddw8t096smLEs_PTycocVVzMGDBgZv5YUlnc,7763
+ tests/tpu_backend_test.py,sha256=1_rEUA2XGsDCbZVX5KFOQ00OyTF4YnKRtNmk6ctbKXc,2462
+ tests/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tests/core/test_adapters.py,sha256=HcZHf0GTwfHtW1rhcvAb1A3ezejQpYzzMuJhUvIsDo4,2927
+ tests/core/test_core_tpu.py,sha256=n6IPk3VzaFYgm3LDeDp1qoKgRN5ysL7JidFOex2lIDg,22342
+ tests/core/test_disagg_executor.py,sha256=QdE2YZs08EyDDCmSjhiXkXqQ9BJTgO6csr_E1xkkfSg,2256
+ tests/core/test_disagg_utils.py,sha256=alktTGppaGdg-_un0Amz8Y0IDQz-xNJN0dXG-YApEmY,1955
+ tests/core/test_init.py,sha256=NEFI5A9eKGu4rmeJ2iqd0EmhlA3bzbVkXmMi1PV1b9U,1687
+ tests/kernels/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tests/kernels/quantized_matmul_kernel_test.py,sha256=od5-zXFjcsc_gWGRDrREL8E_ftymNniQVTzgtkBo_Gc,5679
+ tests/kernels/ragged_kv_cache_update_v2_test.py,sha256=6-HjP5CoUG-kcuP8MS-JJVMiBnPRo_zadS3VInnO0D4,10821
+ tests/kernels/ragged_paged_attention_kernel_v2_test.py,sha256=pWqo9UYF0tzwgBKO_xYw-TYSPrtAsKcMK5Haj8hFG7I,11340
+ tests/kernels/ragged_paged_attention_kernel_v3_test.py,sha256=Hrd8iUkS1pS3rxeTyY53aYRg_ZL_d3NqgBXvOgnigSU,14838
+ tests/lora/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tests/lora/test_lora.py,sha256=nBbwnmvNgTWjqjKXJ0o_n5k7IksXMt5I9SDbpe6IsfM,4168
+ tpu_inference/__init__.py,sha256=5hJ_YCx4yQJ3HH2BruqWaOtnYi_IapS9no7l62foFFo,1096
+ tpu_inference/backend.py,sha256=V0DveQe4maWGz_hRD4bivwTXIQsANZkEj63_0m7U6nA,2552
+ tpu_inference/logger.py,sha256=HQCz7NefmbturuhOC7-3Ixbtcdgoz4g9FHh2RB6o8cc,334
+ tpu_inference/tpu_info.py,sha256=9UohshkndR6dZpGWpWXfTD4qvIVdVgHf0yOoSEkLTrw,2276
+ tpu_inference/utils.py,sha256=M1JMLFtd_5_za7XAQi2ENY8d7aRC-S7wbpYpLh42tyQ,9533
+ tpu_inference/adapters/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/adapters/vllm_adapters.py,sha256=n_iJ-BM4aGlnuf6Qhgye6u-H9dkzZP4SPufjspqw-dk,1412
+ tpu_inference/adapters/vllm_config_adapters.py,sha256=V9sNdkKYHJpK-OKaaMYXZZP-IhZW6MOe7fJSwQbJngE,4076
+ tpu_inference/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/core/adapters.py,sha256=dTZV95MDUdORJbcdYf1JYNTnNDVmv38V22GE7hiQJmo,4484
+ tpu_inference/core/core_tpu.py,sha256=_iGVOp30qEGJX3MTYFRXRTsCdbdtd6vwtBtIeFA0sy8,32609
+ tpu_inference/core/disagg_executor.py,sha256=dM0cvw2uS-jDlfG4BtsmGAa6hKyhhQ1H-ZQVvn65Xb0,4597
+ tpu_inference/core/disagg_utils.py,sha256=ufWNFWQ5n4YnZpPOtoReHlYo4dlN7AbIqCyqS4an0t4,1572
+ tpu_inference/di/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/di/abstracts.py,sha256=pMC-wD9aVoCmD3RXW5A4oHZ9Islu2C6huG9HYmQvxeY,541
+ tpu_inference/di/host.py,sha256=FKRd5Xs1BVzbXku8A35tZmJDwPVg_66drdpAdSzJ5VI,2601
+ tpu_inference/di/interfaces.py,sha256=LFlfXHWK61apIBb2nEBNjuAsdLLmnxTtrkVGslEKTj8,1524
+ tpu_inference/distributed/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/distributed/tpu_connector.py,sha256=l_5l44BVIIClz4hrv5kWtctoUELHtvEXdqfypXlQh3I,28499
+ tpu_inference/distributed/utils.py,sha256=8AOevmxJi7o9hLXyAydcYh-WaWGS6-BKJpV8kW6-P6E,1494
+ tpu_inference/executors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/executors/ray_distributed_executor.py,sha256=VzAPBVb7c8zwGZFtn1OxnwxQTiZMfLnzeI1P7M69d5k,14888
+ tpu_inference/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/experimental/llama3_jax_stashed.py,sha256=YK1oSIfto9ALo-HB45XfSrbq9XgVbE4m2C-9zRwmSzI,10913
+ tpu_inference/interfaces/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/interfaces/cache.py,sha256=ZSNvjYpRjmm3RlsqENa3R9oHP0L7W4zv8nhCvpNGJLA,813
+ tpu_inference/interfaces/config.py,sha256=f0fJBbp5FAWjfS-gC5UK2ptrn42f-DMLK9y_QB9Hm_U,1211
+ tpu_inference/interfaces/config_parts.py,sha256=QuqV6LH8rPXmWuxraVFKmY08aTaZCy62ndM_vK9JkKQ,2105
+ tpu_inference/interfaces/engine.py,sha256=Z1Vxmf5tiKTT3LMMpMX73urMqK3Uc4ZvO0UW8oXAsxE,1446
+ tpu_inference/interfaces/outputs.py,sha256=ay9DXf_9JnaPc5kPJg3MYbO8frIQTJZxiugP7ow0gUI,580
+ tpu_inference/interfaces/params.py,sha256=Cp8MtBj3LW8-4h9J23AJO4wGvG3aOuFIq3YFS-OG8zA,364
+ tpu_inference/interfaces/platform.py,sha256=_EVTdilqpXJX2rRdypANuojOhDO0BCkUwekaxXQqDvQ,1833
+ tpu_inference/interfaces/request.py,sha256=DRkjdWo5wmkVwQlq9DqpMDPeVPmQd6dfyhN2_k8tezw,950
+ tpu_inference/interfaces/scheduler.py,sha256=cFBRkqVNXHrn-08Zvr9B23YTJUzSehy1rE-Fy2V5nvg,816
+ tpu_inference/kernels/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/kernels/collectives/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/kernels/collectives/all_gather_matmul.py,sha256=0OYLLjlDmkRYScl7lHRi0o___5I5iMiW1gso-_dWSbc,27255
+ tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py,sha256=KdaOIzTfIgUR0CcUTA46tpYH-cxPNoJx2cTMEvHx-Ac,1399
+ tpu_inference/kernels/collectives/util.py,sha256=LbLD6lOxuszbUsykF89gWQqEJUICCZsfzam3EJDPnFE,1859
+ tpu_inference/kernels/flash_attention/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/kernels/flash_attention/kernel.py,sha256=n8gmAFVfchMXlyaSEj8xXJm6AadFt26edQihPRdithY,25897
+ tpu_inference/kernels/quantized_matmul/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/kernels/quantized_matmul/kernel.py,sha256=4oEVUXgWOeOY-PfySHf-iEuUSd9J7GQk_rDSbxa7CXg,14086
+ tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py,sha256=3zhIm73JEE8qOty2_0v3AJlVz13k6qMB5wlXBDyC1EM,35130
+ tpu_inference/kernels/quantized_matmul/util.py,sha256=rf6nIiAj9I2cj4LDvtaZGhcLXEc94o2xgMWasnFaREM,1943
+ tpu_inference/kernels/ragged_paged_attention/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/kernels/ragged_paged_attention/v2/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/kernels/ragged_paged_attention/v2/kernel.py,sha256=OiQGAHhyggbp1PeuasPymopFohKOJjGXcpq9p_S8UWA,34940
+ tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py,sha256=vGp2ZWODTbjyG9z2z0Qf_BX-wYHd5bUybnc_DtOz0nI,10995
+ tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py,sha256=mw80bXBGenroGdrITV0F_EaI2s-Z9KWwqU9WodvJg14,97919
+ tpu_inference/kernels/ragged_paged_attention/v3/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/kernels/ragged_paged_attention/v3/kernel.py,sha256=zc-re4Knsdcfvt2oRO5KGD9-dJs0P8GVJ3yGtclHU2A,54740
+ tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py,sha256=KR2UFpCWjsXCmfMcxxV3yV2DVJp5xcEomOtOKYnSL78,131402
+ tpu_inference/kernels/ragged_paged_attention/v3/util.py,sha256=5ij66Rl7YsjTCH1UERP1W-XXC57sL6ZVPQdTLhMtKHQ,1010
+ tpu_inference/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/layers/common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/layers/common/attention_metadata.py,sha256=St8ZatbY1D7xQACKJH459jMgp3oTP3AQ36mi9FZdrPU,850
+ tpu_inference/layers/jax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/layers/jax/attention_interface.py,sha256=bXBD8C8RTYTyLJOIGcKd1jH_ZruM0jabLj4n98RIKSA,12003
+ tpu_inference/layers/jax/base.py,sha256=Vhts6ZMwNCZ8LbnEXeB0rl3nHdS5hDJWX7HEa7Fl7yE,5775
+ tpu_inference/layers/jax/binary_search.py,sha256=ZQi-z1wG6WTcfVQXeTGOZokX4K1DSf9kCzqfrhEU8lk,12320
+ tpu_inference/layers/jax/constants.py,sha256=NcYg0zAf3ClfP7YMYdYu_F1GngOzZaIxIAHBZDunKw4,2755
+ tpu_inference/layers/jax/layers.py,sha256=yv_lC2tbJuzVL-OaXYooX82Ys8hWZATeH9M78coJ3VI,10633
+ tpu_inference/layers/jax/misc.py,sha256=znKv1Nuq_LgYpaIu0qlzUVDgQWnjjG7aqPJGM8kuwcw,566
+ tpu_inference/layers/jax/rope.py,sha256=3ZyR06vwliipkynHHrvcK-Q_aRhvQKDYBOqBYr3oWM8,7029
+ tpu_inference/layers/jax/rope_interface.py,sha256=X0SruXizlCHGnssFujC1pL07UC4Vsp7-gdBy_Q7JZhI,8375
+ tpu_inference/layers/jax/sharding.py,sha256=L0Uh92oLaXFNNQ0qqzNtBD3x3wnTRexQt8GzsCvqH1k,17874
+ tpu_inference/layers/jax/transformer_block.py,sha256=MBN4_hYCGq_-eyomGVUqplBZugZ2LBWUFOgM1UtUxFY,2952
+ tpu_inference/layers/jax/attention/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/layers/jax/attention/attention.py,sha256=KsGuQpOu7yUpimIr5XBniHKaa2ohx_Ke2YaCOvAG3jc,9837
+ tpu_inference/layers/jax/attention/deepseek_v3_attention.py,sha256=YlagoBMwINv2KRH1dr4oEcH_cQ9QMPB55nO2FQZsWs0,14010
+ tpu_inference/layers/jax/attention/llama4_attention.py,sha256=VvUmfBxQEbHf3F2BrcYDUnq5abj7CSDYeRsNx_eVAh0,6162
+ tpu_inference/layers/jax/moe/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/layers/jax/moe/deepseek_v3_moe.py,sha256=Q6CuwwiZtWYm6iUee1wJoDJrwJE6_bcznTK2HrtXb0M,26089
+ tpu_inference/layers/jax/moe/moe.py,sha256=cA8R1rjbBwNEoNlsPWjeIBB9nvaRDwlEdwQTVg6lTpY,8762
+ tpu_inference/layers/jax/sample/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/layers/jax/sample/rejection_sampler.py,sha256=IRfVWjkbVXp9Sv1YrGMMh-LYx1AwbY-3FTXEO1-Ue9g,20423
+ tpu_inference/layers/jax/sample/sampling.py,sha256=-47SC7AqU4UgyO91zAdYXTgrBfdlQ9I89HFZKwU0eQA,3223
+ tpu_inference/layers/jax/sample/sampling_metadata.py,sha256=c3jHNjh1hkFJ5gxGTEk0qBOZnICeY3EELViF5Omp_Nc,2252
+ tpu_inference/layers/vllm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/layers/vllm/attention.py,sha256=UVhuNCCrz6jdLNotjGtgaR4CVZ4zNmq5VhsiuOTi6_I,6649
+ tpu_inference/layers/vllm/fused_moe.py,sha256=ld_-sIHRdUY2tTTHrzHzCahVxH4P0sZVZrxYBQYSJhE,17455
+ tpu_inference/layers/vllm/linear_common.py,sha256=_YlJtbdaYcck_j-gFLos_k0ycktVWxT8Qo57tR2YqJ8,7749
+ tpu_inference/layers/vllm/sharding.py,sha256=Ck2OzNiucHtrEutDqPQNteu8MEm6isIkE8U5ziowHgM,5779
+ tpu_inference/layers/vllm/quantization/__init__.py,sha256=UGv9cJftrBNoC0pU8SLnTLq3zvqMcolN5YJ6n_J5jf4,1392
+ tpu_inference/layers/vllm/quantization/awq.py,sha256=78H4AYgbvLCrW-5bGbn9_WM1J8KnRzVOInfKSW_QmzQ,8476
+ tpu_inference/layers/vllm/quantization/common.py,sha256=wm3pge6XMTMsLK7_SSdgBP0PvQzz-1mrqN2I6xMqzrc,4218
+ tpu_inference/layers/vllm/quantization/unquantized.py,sha256=QIN6lWfVhN4ikUQlDbD8GhkZcLp1-s1Zi66aqKenmeo,10062
+ tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py,sha256=ifC6UsCY0tB6BO7X-PWtw-ikUc5IhcPcLvo0_RFrEsM,5253
+ tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py,sha256=6sQvsxiWdi5Vte8V9vrQ2abaqGqWpq-mtzU7lGAo-ac,8759
+ tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py,sha256=4y7lYgybpXszpCAtxGFhR8LDEbEoCCeo3DfUSOXxhaQ,5202
+ tpu_inference/lora/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/lora/torch_lora_ops.py,sha256=pr3N7DVfkn3ANijUC6dBoiCtIJW4fdJpKdC3zWBUsxE,3121
+ tpu_inference/lora/torch_punica_tpu.py,sha256=ZfwWpPhkz4VQyxX9KeClx1hhchglsCWl0xpcGZsuMG0,12522
+ tpu_inference/mock/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/mock/vllm_config_utils.py,sha256=FlQshLjoHdgs3C66tYHYbKFUjbk9DhUwY-7HibZk0fI,878
+ tpu_inference/mock/vllm_envs.py,sha256=hHtbFOM45T5EB2tEGecMGbJA0qOI9dmNYcjANgtah98,51477
+ tpu_inference/mock/vllm_logger.py,sha256=vUGnN5nKT--ZvU15YCzODUM_FGiXKhcrrjDGjeN00RQ,7297
+ tpu_inference/mock/vllm_logging_utils.py,sha256=TEUmKj3xHiLzHBnFqAujcxH0t2hBQ04sUaho2RyORnk,486
+ tpu_inference/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/models/common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/models/common/model_loader.py,sha256=kOwc5Dyn433U0F-qZU1D0_k5USkMTY5Em0_WvQfjIYc,17661
+ tpu_inference/models/jax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/models/jax/deepseek_v3.py,sha256=735PSgqxxrYL9JIsohhUXimjSNYMeNlepfRLrYHZ9us,40038
+ tpu_inference/models/jax/llama3.py,sha256=bi-wIgZxR9h_DwoYHczPZXqrcvbzCVwnANuKnak6HcI,13024
+ tpu_inference/models/jax/llama4.py,sha256=WMs4gQxbkEZXo7beVJSwPNyZX0AR6prpSE7RGVb9U74,21733
+ tpu_inference/models/jax/llama_eagle3.py,sha256=STUkAK6XEA7JM3i_Lx36-t5BhkAGeW_xYiq3zYhHP1A,12297
+ tpu_inference/models/jax/phi3.py,sha256=Oz68PE2Z1t8wTed95_w0KMIXfnfV72ZwXugNOdWOV5w,13576
+ tpu_inference/models/jax/qwen2.py,sha256=RYb0hMKzPnFOAyhqbztoNlSrFIlRa74fYqSNecA2VOY,13354
+ tpu_inference/models/jax/qwen2_5_vl.py,sha256=GrUlM16EWsaGPpSnn1KhjcrAHfeJeC1Z3cVefw0-ynQ,38522
+ tpu_inference/models/jax/qwen3.py,sha256=SOL-Pvp56IrMxqXpPf5EFacBI6AJNlqf4Zrr1pkabGw,10994
+ tpu_inference/models/jax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/models/jax/utils/file_utils.py,sha256=NOuSC3YFnZpf3CZgYdghbbiNYJt42zgjlEYbOZIVct4,2840
+ tpu_inference/models/jax/utils/multi_modal_utils.py,sha256=huW_yfntOJ_3ZXYUN1tJtmeK7EMoOBZExTZQtvfHOdk,6189
+ tpu_inference/models/jax/utils/weight_utils.py,sha256=lZIW-39BA6GzdMZ_nr-CapBttLsfEajJvMJo8ykr0B0,19507
+ tpu_inference/models/jax/utils/quantization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/models/jax/utils/quantization/quantization_utils.py,sha256=hpzEzosiGi_02bgBXzW-AwZnKEiP_NPiKvpLSIPNjD4,24519
+ tpu_inference/models/vllm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/models/vllm/vllm_model_wrapper.py,sha256=CyA9Gk8rmL1_FmIJ0NQcsutkwZn_DBZlzwuib2M2HuI,11141
+ tpu_inference/models/vllm/vllm_model_wrapper_context.py,sha256=yxlJHPmRQIAwlb1MmHK3xfXokgIkJ-evNU4PgyoJUdg,1187
+ tpu_inference/platforms/__init__.py,sha256=2m4E-nxkBhYZFG23Ni4_AFpZe8xQTimdRltkrNzp7WA,69
+ tpu_inference/platforms/tpu_jax.py,sha256=oKQFXjNF6cK2QZT7bqgb50oBwr-FN4VO0VdQXl1TQmE,9941
+ tpu_inference/runner/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/runner/block_table_jax.py,sha256=HCjrOMpsWk_x3lW--AOPPUHBHplIzGioMTuHKFxtr6A,4164
+ tpu_inference/runner/compilation_manager.py,sha256=16n36Ne4LbmPei8UIAnUMw4TrLcBpe7a5Kvc3oibqcA,30904
+ tpu_inference/runner/input_batch_jax.py,sha256=lqFGhZ3w92MPzpiGJ6bNUsQC8X1AP8JpKiWMeXs5tto,18260
+ tpu_inference/runner/kv_cache.py,sha256=dU7DRJn0--qgPLV00jCIw4sabSf007mO5kCWnNrNeDI,3952
+ tpu_inference/runner/kv_cache_manager.py,sha256=bDkbfpQ41L-n6R-LrseZE85DIuTtu4vbt4mCj1MJa48,21467
+ tpu_inference/runner/lora_utils.py,sha256=XFNHPJvZe4e87tbyyKpOY9Vb28M9Rza3HXHNsem7jVg,3872
+ tpu_inference/runner/multimodal_manager.py,sha256=2-QQcLuWikP7JgmC3tGovNDYfvikZl3tWyAiX4x8YDc,9283
+ tpu_inference/runner/persistent_batch_manager.py,sha256=Zo8w2EdFZTSQtx6DCl57P8kQWkebquXl22RGIX2yqec,11160
+ tpu_inference/runner/speculative_decoding_manager.py,sha256=_2oAwo_8e4N-FJXjC9oR-fsO8WjukCdvQhPH4R8B-c4,10274
+ tpu_inference/runner/structured_decoding_manager.py,sha256=0SIoa5orxDcx76ziatKJ-GfnTAVIPCPTaMS15nxRR5U,3673
+ tpu_inference/runner/tpu_jax_runner.py,sha256=HA-PBThgXv0GHfFxA9ltQr7fQFDw8rwE4SKMsYV0zMI,34285
+ tpu_inference/runner/utils.py,sha256=5QcZW8an8EHs_zHKzIGqIf3ltAevusdwgaLLFrB9rc8,17131
+ tpu_inference/spec_decode/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/spec_decode/jax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/spec_decode/jax/eagle3.py,sha256=PgIAJMuEyy61Tz4SQ6QZqB-B4t4-RYDmUIoHDyOHEjA,15204
+ tpu_inference/worker/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tpu_inference/worker/_temporary_vllm_compat.py,sha256=GpF8TuPMDbc0fvIxe7XWEe69FES_F-jJnmcaTgf2dO8,5182
+ tpu_inference/worker/base.py,sha256=0Dd3CKk3e7DgvzhfH4M-9-MEQNyYh4zUWSO4tnHFd6s,3140
+ tpu_inference/worker/tpu_worker_jax.py,sha256=7b2QVTSbveifm9_BgNnVGwEvh5zPrEi1qiXXTwFFODc,14093
+ tpu_inference-0.11.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ tpu_inference-0.11.1.dist-info/METADATA,sha256=uKyRzPptKu13NN6_lYOinPlLYk57ZUFleECr2JDgLrs,5393
+ tpu_inference-0.11.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ tpu_inference-0.11.1.dist-info/top_level.txt,sha256=gb1hRIQ3DOawUfVzvPL2E__2KPIl9I0vb5r0xcRBGYQ,20
+ tpu_inference-0.11.1.dist-info/RECORD,,
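
Each RECORD row above is `path,sha256=<digest>,size`, where the digest is the urlsafe-base64 sha256 of the file bytes with trailing `=` padding stripped (per the wheel spec). A small sketch that recomputes one entry from the archive (the wheel path is assumed to be a local download of this release):

```python
import base64
import hashlib
import zipfile

def record_digest(data: bytes) -> str:
    # urlsafe base64 of the sha256 digest, without '=' padding
    return base64.urlsafe_b64encode(
        hashlib.sha256(data).digest()).rstrip(b"=").decode()

with zipfile.ZipFile("tpu_inference-0.11.1-py3-none-any.whl") as whl:
    data = whl.read("tpu_inference/logger.py")

print(len(data))            # should match the size column: 334
print(record_digest(data))  # should match: HQCz7NefmbturuhOC7-3Ixbtcdgoz4g9FHh2RB6o8cc
```
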
tpu_inference-0.11.1.dist-info/WHEEL
@@ -0,0 +1,5 @@
+ Wheel-Version: 1.0
+ Generator: setuptools (80.9.0)
+ Root-Is-Purelib: true
+ Tag: py3-none-any
+
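
WHEEL is RFC 822-style metadata: `Root-Is-Purelib: true` means the package installs into the pure-Python site-packages directory, and `Tag: py3-none-any` marks it as pure Python for any Python 3 on any platform. A sketch that reads these fields straight out of the archive (again assuming a local copy of the wheel):

```python
import zipfile
from email.parser import Parser

with zipfile.ZipFile("tpu_inference-0.11.1-py3-none-any.whl") as whl:
    raw = whl.read("tpu_inference-0.11.1.dist-info/WHEEL").decode()

msg = Parser().parsestr(raw)
print(msg["Wheel-Version"])    # 1.0
print(msg["Root-Is-Purelib"])  # true
print(msg["Tag"])              # py3-none-any
```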