tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic.

Files changed (76)
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/kernels/mla_v1_test.py +129 -41
  3. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  4. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
  5. tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
  6. tests/lora/test_layers.py +4 -7
  7. tests/lora/test_lora_perf.py +53 -0
  8. tests/lora/utils.py +0 -8
  9. tests/test_envs.py +110 -12
  10. tests/test_quantization.py +3 -0
  11. tests/test_utils.py +1 -2
  12. tpu_inference/__init__.py +22 -3
  13. tpu_inference/core/disagg_utils.py +6 -8
  14. tpu_inference/distributed/tpu_connector.py +3 -4
  15. tpu_inference/distributed/utils.py +3 -2
  16. tpu_inference/envs.py +93 -9
  17. tpu_inference/executors/ray_distributed_executor.py +9 -2
  18. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  19. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  20. tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
  21. tpu_inference/kernels/mla/v1/kernel.py +98 -120
  22. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  23. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  24. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  25. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +140 -67
  26. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +204 -120
  27. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
  28. tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
  29. tpu_inference/layers/common/attention_interface.py +7 -1
  30. tpu_inference/layers/common/sharding.py +11 -7
  31. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
  32. tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
  33. tpu_inference/layers/vllm/fused_moe.py +170 -208
  34. tpu_inference/layers/vllm/linear_common.py +43 -21
  35. tpu_inference/layers/vllm/quantization/common.py +11 -6
  36. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
  37. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
  38. tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
  39. tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
  40. tpu_inference/layers/vllm/sharding.py +2 -2
  41. tpu_inference/lora/torch_punica_tpu.py +1 -2
  42. tpu_inference/models/common/model_loader.py +84 -28
  43. tpu_inference/models/jax/deepseek_v3.py +185 -64
  44. tpu_inference/models/jax/gpt_oss.py +3 -3
  45. tpu_inference/models/jax/llama3.py +2 -1
  46. tpu_inference/models/jax/llama_eagle3.py +8 -5
  47. tpu_inference/models/jax/llama_guard_4.py +361 -0
  48. tpu_inference/models/jax/qwen2.py +2 -1
  49. tpu_inference/models/jax/qwen2_5_vl.py +163 -48
  50. tpu_inference/models/jax/qwen3.py +2 -1
  51. tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
  52. tpu_inference/models/jax/utils/weight_utils.py +205 -144
  53. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -8
  54. tpu_inference/platforms/tpu_platform.py +34 -50
  55. tpu_inference/runner/compilation_manager.py +144 -60
  56. tpu_inference/runner/kv_cache.py +40 -20
  57. tpu_inference/runner/kv_cache_manager.py +48 -33
  58. tpu_inference/runner/persistent_batch_manager.py +40 -2
  59. tpu_inference/runner/structured_decoding_manager.py +2 -3
  60. tpu_inference/runner/tpu_runner.py +280 -149
  61. tpu_inference/runner/utils.py +2 -2
  62. tpu_inference/spec_decode/jax/eagle3.py +71 -21
  63. tpu_inference/tpu_info.py +4 -3
  64. tpu_inference/utils.py +46 -18
  65. tpu_inference/worker/tpu_worker.py +197 -63
  66. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +9 -10
  67. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +70 -74
  68. tpu_inference/mock/__init__.py +0 -0
  69. tpu_inference/mock/vllm_config_utils.py +0 -28
  70. tpu_inference/mock/vllm_envs.py +0 -1219
  71. tpu_inference/mock/vllm_logger.py +0 -212
  72. tpu_inference/mock/vllm_logging_utils.py +0 -15
  73. tpu_inference/models/jax/phi3.py +0 -376
  74. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
  75. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
  76. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
tpu_inference/worker/tpu_worker.py

@@ -2,14 +2,15 @@
  
  import os
  import tempfile
+ from dataclasses import dataclass, field
  from typing import Callable, Dict, Optional, Tuple
  
  import jax
- import jax.numpy as jnp
  import jaxlib
  import jaxtyping
  import vllm.envs as vllm_envs
  from vllm.config import VllmConfig, set_current_vllm_config
+ from vllm.distributed import get_pp_group
  from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
                                            has_kv_transfer_group)
  from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
@@ -17,52 +18,59 @@ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
  from vllm.lora.request import LoRARequest
  from vllm.tasks import SupportedTask
  from vllm.v1 import utils as vllm_utils
- from vllm.v1.core.kv_cache_utils import get_num_blocks, get_uniform_page_size
+ from vllm.v1.core.kv_cache_utils import (get_kv_cache_groups, get_num_blocks,
+                                          get_uniform_page_size)
  from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
  from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
  from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
  
  from tpu_inference import envs, utils
+ from tpu_inference.distributed import jax_parallel_state
  from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port,
                                               get_node_id)
  from tpu_inference.layers.common.sharding import ShardingConfigManager
  from tpu_inference.logger import init_logger
- from tpu_inference.runner.kv_cache import get_rpa_page_size_bytes
+ from tpu_inference.models.jax.jax_intermediate_tensor import \
+     JaxIntermediateTensors
+ from tpu_inference.runner.kv_cache import get_attention_page_size_bytes
  from tpu_inference.runner.tpu_runner import TPUModelRunner
  
  logger = init_logger(__name__)
  
- _DTYPE: dict[str, jnp.dtype] = {
-     "bfloat16": jnp.bfloat16,
-     "float": jnp.float32,
-     "float32": jnp.float32,
- }
  
+ @dataclass
+ class PPConfig:
+     rank: int
+     ip: str
+     prev_worker_ip: str
+     pp_world_size: int
  
- class TPUWorker:
+     # default env vars for
+     # TPU_PROCESS_BOUNDS, TPU_CHIPS_PER_PROCESS_BOUNDS, TPU_VISIBLE_CHIPS
+     # if PP is used in single host.
+     default_tpu_process_bounds: str = field(init=False)
+     default_tpu_chips_per_process_bounds: str = field(init=False)
+     default_tpu_visible_chips: str = field(init=False)
+
+     def __post_init__(self):
+         self.default_tpu_process_bounds = f"1,{self.pp_world_size},1"
+         self.default_tpu_chips_per_process_bounds = "1,1,1"
+         self.default_tpu_visible_chips = f"{self.rank}"
  
-     def __init__(self,
-                  vllm_config: VllmConfig,
-                  local_rank: int,
-                  rank: int,
-                  distributed_init_method: str,
-                  is_driver_worker: bool = False,
-                  devices=None):
-         # If we use vLLM's model implementation in PyTorch, we should set it
-         # with torch version of the dtype.
-         impl = envs.MODEL_IMPL_TYPE
-         if impl != "vllm":  # vllm-pytorch implementation does not need this conversion
- 
-             # NOTE(wenlong): because sometimes mm needs to use torch for preprocessing
-             if not isinstance(vllm_config.model_config.dtype, str):
-                 logger.warning(
-                     "The model dtype is not properly set for JAX backend. "
-                     "Overwriting it to jnp.bfloat16")
-                 vllm_config.model_config.dtype = jnp.bfloat16
-             else:
-                 vllm_config.model_config.dtype = _DTYPE.get(
-                     vllm_config.model_config.dtype, jnp.bfloat16)
  
+ class TPUWorker:
+ 
+     def __init__(
+         self,
+         vllm_config: VllmConfig,
+         local_rank: int,
+         rank: int,
+         distributed_init_method: str,
+         is_driver_worker: bool = False,
+         devices=None,
+         ip: str = "localhost",
+         prev_worker_ip: str = "localhost",
+     ):
          self.vllm_config = vllm_config
          self.model_config = vllm_config.model_config
          self.parallel_config = vllm_config.parallel_config
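For orientation, the `PPConfig` dataclass added above derives its single-host TPU defaults from the worker rank and the pipeline-parallel world size in `__post_init__`. A minimal standalone sketch of that derivation (the class is re-declared here purely for illustration; only the three derived strings mirror the diff):

```python
from dataclasses import dataclass, field


@dataclass
class PPConfig:
    """Illustrative copy of the dataclass introduced in this diff."""
    rank: int
    ip: str
    prev_worker_ip: str
    pp_world_size: int
    default_tpu_process_bounds: str = field(init=False)
    default_tpu_chips_per_process_bounds: str = field(init=False)
    default_tpu_visible_chips: str = field(init=False)

    def __post_init__(self):
        # One TPU process per pipeline stage, laid out along the second axis.
        self.default_tpu_process_bounds = f"1,{self.pp_world_size},1"
        # Each pipeline-parallel worker process owns a single chip by default.
        self.default_tpu_chips_per_process_bounds = "1,1,1"
        # The worker's rank selects which chip it can see.
        self.default_tpu_visible_chips = f"{self.rank}"


# Worker 1 of a 4-stage pipeline on a single host:
cfg = PPConfig(rank=1, ip="localhost", prev_worker_ip="localhost", pp_world_size=4)
assert cfg.default_tpu_process_bounds == "1,4,1"
assert cfg.default_tpu_chips_per_process_bounds == "1,1,1"
assert cfg.default_tpu_visible_chips == "1"
```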
@@ -74,10 +82,12 @@ class TPUWorker:
          self.devices = devices if devices is not None else []
          self.device_ranks = set(device.id for device in self.devices
                                  if isinstance(device, jaxlib._jax.Device))
+         self.pp_config = PPConfig(rank, ip, prev_worker_ip,
+                                   self.parallel_config.pipeline_parallel_size)
  
          if self.model_config.trust_remote_code:
              # note: lazy import to avoid importing torch before initializing
-             from vllm.utils import init_cached_hf_modules
+             from vllm.utils.import_utils import init_cached_hf_modules
  
              init_cached_hf_modules()
  
@@ -86,7 +96,7 @@ class TPUWorker:
          # TPU Worker is initialized. The profiler server needs to start after
          # MP runtime is initialized.
          self.profile_dir = None
-         if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
+         if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1 and self.pp_config.pp_world_size == 1:
              if not self.devices or 0 in self.device_ranks:
                  # For TPU, we can only have 1 active profiler session for 1 profiler
                  # server. So we only profile on rank0.
@@ -94,6 +104,14 @@ class TPUWorker:
                  logger.info("Profiling enabled. Traces will be saved to: %s",
                              self.profile_dir)
  
+         # For PP, we use MPMD so we want to profile every worker.
+         if self.pp_config.pp_world_size > 1 and vllm_envs.VLLM_TORCH_PROFILER_DIR:
+             self.profile_dir = os.path.join(
+                 vllm_envs.VLLM_TORCH_PROFILER_DIR,
+                 f"pprank_{self.rank}_ppworldsize_{self.pp_config.pp_world_size}"
+             )
+             os.makedirs(self.profile_dir, exist_ok=True)
+ 
          use_jax_profiler_server = os.getenv("USE_JAX_PROFILER_SERVER", False)
          # Only one instance of profiler is allowed
          if use_jax_profiler_server and self.rank < 1:
@@ -105,31 +123,87 @@ class TPUWorker:
              )
              jax.profiler.start_server(jax_profiler_server_port)
  
+         # step_counter is used to calculate uuid to transfer intermediate tensors.
+         self.step_counter = 0
+ 
      def initialize_cache(self, num_gpu_blocks: int,
                           num_cpu_blocks: int) -> None:
          self.cache_config.num_gpu_blocks = num_gpu_blocks
          self.cache_config.num_cpu_blocks = num_cpu_blocks
  
-     def init_device(self):
+     def init_device(self,
+                     tpu_process_bounds="",
+                     tpu_chips_per_process_bounds="",
+                     tpu_visible_chips=""):
+         # set tpu visible devices for Jax runtime in single host PP.
+         multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+         if multihost_backend != "ray" and self.parallel_config.pipeline_parallel_size > 1:
+             tpu_ports = [
+                 jax_parallel_state.BASE_JAX_PORT + i
+                 for i in range(self.pp_config.pp_world_size)
+             ]
+             os.environ["TPU_PROCESS_ADDRESSES"] = ",".join(
+                 [f"localhost:{port}" for port in tpu_ports])
+             os.environ["TPU_PROCESS_PORT"] = f"{tpu_ports[self.rank]}"
+             os.environ["CLOUD_TPU_TASK_ID"] = f"{self.rank}"
+ 
+             # Note: Below is the setting for v6e8 host (8 chips of v6e)
+             # Replace with your own topology.
+             # There are 2 ways of subslicing a v6e
+             # 1) 2 slices with 4 TPU chips each, we can do PP=2, TP=1/2/3/4
+             # TPU_PROCESS_BOUNDS = "1,1,1"
+             # TPU_CHIPS_PER_PROCESS_BOUNDS = "1,4,1"
+             # TPU_VISIBLE_CHIPS = "0,1,2,3" or "4,5,6,7"
+             # 2) 1 chip for each subslice, with at most 8 subslices,
+             # we can do TP=1, PP=1/2/3/4/5/6/7/8
+             os.environ[
+                 "TPU_PROCESS_BOUNDS"] = tpu_process_bounds \
+                 if tpu_process_bounds \
+                 else self.pp_config.default_tpu_process_bounds
+             os.environ[
+                 "TPU_CHIPS_PER_PROCESS_BOUNDS"] = tpu_chips_per_process_bounds \
+                 if tpu_chips_per_process_bounds \
+                 else self.pp_config.default_tpu_chips_per_process_bounds
+             os.environ[
+                 "TPU_VISIBLE_CHIPS"] = tpu_visible_chips \
+                 if tpu_visible_chips \
+                 else self.pp_config.default_tpu_visible_chips
+ 
          if not self.devices:
              sharding_config: ShardingConfigManager = self.vllm_config.sharding_config
              device_indexes = sharding_config.device_indexes
              if device_indexes is not None and len(device_indexes) > 0:
                  # Enforcing the devices sequence to be consistent with the specified device indexes
-                 all_devices = jax.devices()
-                 device_dict = {device.id: device for device in all_devices}
+                 all_local_devices = jax.local_devices()
+                 device_dict = {
+                     device.id: device
+                     for device in all_local_devices
+                 }
                  self.devices = []
                  for device_index in device_indexes:
                      device = device_dict[device_index]
                      if device is None:
                          raise KeyError(
                              f"Device index {device_index} not found in "
-                             f"jax.devices() with IDs {list(device_dict.keys())}!"
+                             f"jax.local_devices() with IDs {list(device_dict.keys())}!"
                          )
                      self.devices.append(device)
+                 assert len(self.devices) >= sharding_config.total_devices
                  self.devices = self.devices[:sharding_config.total_devices]
              else:
-                 self.devices = jax.devices()[:sharding_config.total_devices]
+                 if self.pp_config.pp_world_size > 1:
+                     # We only support a mixed tp + pp scenario that tp size is
+                     # smaller or equals the total TPUs in one node
+                     # say: we have 4 nodes with 4 TPUs each, we can only do pp:4, tp:4, but not pp:2, tp:8
+                     assert jax.local_device_count(
+                     ) >= sharding_config.total_devices
+                     self.devices = jax.local_devices()[:sharding_config.
+                                                        total_devices]
+                 else:
+                     # In a multi-host distributed env, say: Ray, local_device count may smaller
+                     # than the total devices, we just choose the smaller set here.
+                     self.devices = jax.devices()[:sharding_config.
+                                                  total_devices]
  
          # Initialize the vLLM distribution layer as a single chip environment,
          # we'll swap the model's parallel modules with TPU SPMD equivalents.
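To make the single-host PP branch of `init_device` above concrete, the sketch below computes the same environment variables for a hypothetical 2-stage pipeline on one host. The `BASE_JAX_PORT` constant is a made-up stand-in for `jax_parallel_state.BASE_JAX_PORT`, whose real value is not shown in this diff.

```python
import os

# Stand-in value; the real constant lives in
# tpu_inference.distributed.jax_parallel_state.BASE_JAX_PORT.
BASE_JAX_PORT = 8700


def export_single_host_pp_env(rank: int, pp_world_size: int) -> dict[str, str]:
    """Mirror of the env-var setup one PP worker performs in init_device."""
    tpu_ports = [BASE_JAX_PORT + i for i in range(pp_world_size)]
    env = {
        # Every stage's coordinator address, plus this worker's own port and task id.
        "TPU_PROCESS_ADDRESSES": ",".join(f"localhost:{p}" for p in tpu_ports),
        "TPU_PROCESS_PORT": str(tpu_ports[rank]),
        "CLOUD_TPU_TASK_ID": str(rank),
        # PPConfig defaults: one process per stage, one visible chip per process.
        "TPU_PROCESS_BOUNDS": f"1,{pp_world_size},1",
        "TPU_CHIPS_PER_PROCESS_BOUNDS": "1,1,1",
        "TPU_VISIBLE_CHIPS": str(rank),
    }
    os.environ.update(env)
    return env


# Worker 1 of a 2-stage single-host pipeline ends up with
# TPU_PROCESS_ADDRESSES="localhost:8700,localhost:8701", TPU_VISIBLE_CHIPS="1", etc.
print(export_single_host_pp_env(rank=1, pp_world_size=2))
```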
@@ -146,15 +220,40 @@ class TPUWorker:
              tensor_model_parallel_size=1,
              pipeline_model_parallel_size=1,
          )
+ 
+         jax_parallel_state.init_pp_distributed_environment(
+             self.pp_config.ip,
+             self.rank,
+             self.parallel_config.pipeline_parallel_size,
+             self.devices[0],
+             need_pp=self.parallel_config.pipeline_parallel_size > 1)
+ 
          ensure_kv_transfer_initialized(self.vllm_config)
-         self.model_runner = TPUModelRunner(self.vllm_config, self.devices)
+ 
+         is_first_rank = True
+         is_last_rank = True
+         if self.parallel_config.pipeline_parallel_size > 1:
+             is_first_rank = self.rank == 0
+             is_last_rank = self.rank == self.pp_config.pp_world_size - 1
+ 
+         self.model_runner = TPUModelRunner(self.vllm_config, self.devices,
+                                            self.rank, is_first_rank,
+                                            is_last_rank)
          logger.info(f"Init worker | "
                      f"rank={self.rank} | "
+                     f"is_first_rank={is_first_rank} | "
+                     f"is_last_rank={is_last_rank} | "
                      f"node_id={get_node_id()} | "
                      f"is_driver_worker={self.is_driver_worker} | "
                      f"hbm={utils.hbm_usage_gb(self.devices)}GiB")
          vllm_utils.report_usage_stats(self.vllm_config)
  
+     def initialize_pp_transfer_connect(self):
+         if self.rank == 0:
+             return
+         jax_parallel_state.connect(self.pp_config.prev_worker_ip,
+                                    self.rank - 1)
+ 
      def determine_available_memory(self) -> int:
          gpu_memory_utilization = self.cache_config.gpu_memory_utilization
          hbm_usage = utils.hbm_usage_bytes(self.devices)
@@ -194,14 +293,39 @@ class TPUWorker:
          # deliberate, temporary compromise for the same reasons outlined in
          # the `get_kv_cache_spec` method.
  
-         output = self.model_runner.execute_model(scheduler_output)
- 
-         # With a connector, the scheduler expects output from all workers
-         # TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
-         if has_kv_transfer_group():
-             return output
- 
-         return output if self.is_driver_worker else None
+         if self.parallel_config.pipeline_parallel_size == 1 or self.rank == 0:
+             intermediate_tensors = None
+         else:
+             # receive intermediate tensors
+             uuid = self.model_runner.get_uuid_for_jax_transfer(
+                 scheduler_output, self.rank - 1, self.step_counter)
+             # TODO: this method might only works for vllm model, not sure about jax models.
+             tensor_spec = self.model_runner.get_intermediate_tensor_spec(
+                 scheduler_output.total_num_scheduled_tokens)
+             intermediate_tensors_dict = get_pp_group().recv_tensor_dict(
+                 uuid, tensor_spec)
+             intermediate_tensors = JaxIntermediateTensors(
+                 intermediate_tensors_dict)
+ 
+         output = self.model_runner.execute_model(scheduler_output,
+                                                  intermediate_tensors)
+ 
+         if isinstance(output, JaxIntermediateTensors):
+             assert self.parallel_config.pipeline_parallel_size > 1
+             assert not get_pp_group().is_last_rank
+             # send intermediate tensors
+             uuid = self.model_runner.get_uuid_for_jax_transfer(
+                 scheduler_output, self.rank, self.step_counter)
+             get_pp_group().send_tensor_dict(uuid, output.tensors)
+             self.step_counter += 1
+             return None
+         else:
+             self.step_counter += 1
+             # With a connector, the scheduler expects output from all workers
+             # TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
+             if has_kv_transfer_group():
+                 return output
+             return output if self.is_driver_worker else None
  
      def sample_tokens(self,
                        grammar_output: GrammarOutput) -> ModelRunnerOutput:
@@ -221,7 +345,7 @@ class TPUWorker:
          if is_start:
              options = jax.profiler.ProfileOptions()
              # default: https://docs.jax.dev/en/latest/profiling.html#general-options
-             options.python_tracer_level = os.getenv("PYTHON_TRACER_LEVEL", 0)
+             options.python_tracer_level = envs.PYTHON_TRACER_LEVEL
              options.host_tracer_level = os.getenv("HOST_TRACER_LEVEL", 1)
              jax.profiler.start_trace(self.profile_dir,
                                       profiler_options=options)
@@ -259,32 +383,37 @@ class TPUWorker:
          # responsible for this translation. When vLLM can be modified, this
          # method should be changed to return `dict[str, AbstractKVCacheSpec]`,
          # and the vLLM side should be updated to handle the translation.
-         kv_cache_specs = self.model_runner.get_kv_cache_spec()
+         kv_cache_spec = self.model_runner.get_kv_cache_spec()
  
-         if len(kv_cache_specs) == 0:
-             return kv_cache_specs
+         if len(kv_cache_spec) == 0:
+             return kv_cache_spec
  
          # TODO(kyuyeunk): Instead of checking page_size_bytes here, introduce
          # feature that allows overriding page_size_bytes of KVCacheSpec.
-         vllm_page_size_bytes = get_uniform_page_size(kv_cache_specs)
-         rpa_page_size_bytes = get_rpa_page_size_bytes(self.model_runner.mesh,
-                                                       kv_cache_specs)
+         vllm_page_size_bytes = get_uniform_page_size(
+             list(kv_cache_spec.values()))
+         attention_page_size_bytes = get_attention_page_size_bytes(
+             self.model_runner.mesh, kv_cache_spec)
  
-         if vllm_page_size_bytes != rpa_page_size_bytes:
+         if vllm_page_size_bytes != attention_page_size_bytes:
              logger.info(
-                 f"KV cache page size calculated by vLLM "
-                 f"({vllm_page_size_bytes} Bytes) does not match with actual "
-                 f"page size used by RPA kernel ({rpa_page_size_bytes} Bytes). "
-                 f"Recalculating number of KV blocks using actual page size.")
- 
+                 f"Page size calculated by vLLM ({vllm_page_size_bytes} Bytes) "
+                 f"does not match with actual page size used by the kernel "
+                 f"({attention_page_size_bytes} Bytes). Recalculating number of "
+                 f"KV blocks using actual page size.")
+ 
+             kv_cache_groups = get_kv_cache_groups(self.vllm_config,
+                                                   kv_cache_spec)
+             group_size = max(
+                 len(group.layer_names) for group in kv_cache_groups)
              available_memory = self.determine_available_memory()
-             num_blocks = get_num_blocks(self.vllm_config, len(kv_cache_specs),
-                                         available_memory, rpa_page_size_bytes)
- 
+             num_blocks = get_num_blocks(self.vllm_config, group_size,
+                                         available_memory,
+                                         attention_page_size_bytes)
              cache_config = self.vllm_config.cache_config
              cache_config.num_gpu_blocks_override = num_blocks
  
-         return kv_cache_specs
+         return kv_cache_spec
  
      def initialize_from_config(
          self,
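As a rough back-of-the-envelope for the recalculation in `get_kv_cache_spec` above, assuming the block count scales as available memory divided by the per-block footprint of the largest KV-cache group (the numbers below are invented for illustration):

```python
# Hypothetical figures, not taken from the diff.
attention_page_size_bytes = 2 * 1024 * 1024   # 2 MiB per block per layer, as reported by the kernel
group_size = 32                               # layers in the largest KV-cache group
available_memory = 16 * 1024**3               # 16 GiB of HBM left for the KV cache

# Approximate block count that num_gpu_blocks_override would be set to.
num_blocks = available_memory // (attention_page_size_bytes * group_size)
print(num_blocks)  # 256
```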
@@ -319,3 +448,8 @@ class TPUWorker:
  
      def shutdown(self) -> None:
          return
+ 
+     # Ray executor do not need handshake metadata
+     # as we pass the kv_parameters through proxy server
+     def get_kv_connector_handshake_metadata(self) -> None:
+         pass
{tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: tpu_inference
- Version: 0.11.1.dev202511180814
+ Version: 0.12.0.dev20251213
  Author: tpu_inference Contributors
  Classifier: Development Status :: 3 - Alpha
  Classifier: Intended Audience :: Developers
@@ -14,7 +14,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Requires-Python: >=3.10
  Description-Content-Type: text/markdown
  License-File: LICENSE
- Requires-Dist: tpu-info==0.4.0
+ Requires-Dist: tpu-info==0.7.1
  Requires-Dist: yapf==0.43.0
  Requires-Dist: pytest
  Requires-Dist: pytest-mock
@@ -25,12 +25,13 @@ Requires-Dist: jax[tpu]==0.8.0
  Requires-Dist: jaxlib==0.8.0
  Requires-Dist: jaxtyping
  Requires-Dist: flax==0.11.1
- Requires-Dist: torchax==0.0.7
+ Requires-Dist: torchax==0.0.10
  Requires-Dist: qwix==0.1.1
- Requires-Dist: torchvision==0.23.0
+ Requires-Dist: torchvision==0.24.0
  Requires-Dist: pathwaysutils
  Requires-Dist: parameterized
  Requires-Dist: numba==0.62.1
+ Requires-Dist: runai-model-streamer[gcs,s3]==0.15.0
  Dynamic: author
  Dynamic: classifier
  Dynamic: description
@@ -52,14 +53,12 @@ Dynamic: requires-python
  
  ---
  
- _Upcoming Events_ 🔥
- 
- - Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundation.org/pytorch-conference/) in San Francisco!
- - Join us at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
- - Join us at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
- 
  _Latest News_ 🔥
  
+ - [Pytorch Conference](https://pytorchconference.sched.com/event/27QCh/sponsored-session-everything-everywhere-all-at-once-vllm-hardware-optionality-with-spotify-and-google-brittany-rockwell-google-shireen-kheradpey-spotify) Learn how Spotify uses vLLM with both GPUs and TPUs to drive down costs and improve user experience.
+ - Check back soon for a recording of our session at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
+ - Check back soon for a recording of our session at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
+ 
  - [2025/10] [vLLM TPU: A New Unified Backend Supporting PyTorch and JAX on TPU](https://blog.vllm.ai/2025/10/16/vllm-tpu.html)
  
  <details>