tpu-inference 0.11.1.dev202511220812-py3-none-any.whl → 0.12.0.dev20251213-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (59)
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/kernels/mla_v1_test.py +129 -41
  3. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  4. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
  5. tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
  6. tests/lora/test_layers.py +4 -1
  7. tests/lora/test_lora_perf.py +53 -0
  8. tests/test_envs.py +110 -12
  9. tests/test_quantization.py +3 -0
  10. tests/test_utils.py +1 -2
  11. tpu_inference/distributed/tpu_connector.py +1 -1
  12. tpu_inference/envs.py +92 -8
  13. tpu_inference/executors/ray_distributed_executor.py +5 -1
  14. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  15. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  16. tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
  17. tpu_inference/kernels/mla/v1/kernel.py +98 -120
  18. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  19. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  20. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  21. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +82 -32
  22. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +146 -85
  23. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
  24. tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
  25. tpu_inference/layers/common/attention_interface.py +7 -1
  26. tpu_inference/layers/common/sharding.py +11 -7
  27. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
  28. tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
  29. tpu_inference/layers/vllm/fused_moe.py +170 -208
  30. tpu_inference/layers/vllm/linear_common.py +43 -21
  31. tpu_inference/layers/vllm/quantization/common.py +11 -6
  32. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
  33. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
  34. tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
  35. tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
  36. tpu_inference/models/common/model_loader.py +78 -22
  37. tpu_inference/models/jax/deepseek_v3.py +185 -64
  38. tpu_inference/models/jax/gpt_oss.py +3 -3
  39. tpu_inference/models/jax/llama_eagle3.py +4 -5
  40. tpu_inference/models/jax/qwen2_5_vl.py +161 -47
  41. tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
  42. tpu_inference/models/jax/utils/weight_utils.py +203 -155
  43. tpu_inference/models/vllm/vllm_model_wrapper.py +11 -5
  44. tpu_inference/platforms/tpu_platform.py +29 -48
  45. tpu_inference/runner/compilation_manager.py +112 -46
  46. tpu_inference/runner/kv_cache.py +40 -20
  47. tpu_inference/runner/kv_cache_manager.py +40 -31
  48. tpu_inference/runner/persistent_batch_manager.py +40 -2
  49. tpu_inference/runner/structured_decoding_manager.py +2 -3
  50. tpu_inference/runner/tpu_runner.py +94 -51
  51. tpu_inference/runner/utils.py +2 -2
  52. tpu_inference/spec_decode/jax/eagle3.py +71 -22
  53. tpu_inference/utils.py +41 -14
  54. tpu_inference/worker/tpu_worker.py +43 -45
  55. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +8 -9
  56. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +59 -58
  57. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
  58. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
  59. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
tpu_inference/spec_decode/jax/eagle3.py CHANGED
@@ -6,13 +6,19 @@ from typing import Any, Optional
  import jax
  import jax.numpy as jnp
  import numpy as np
+ from flax import nnx
+ from jax import lax
+ from jax.sharding import NamedSharding, PartitionSpec
  from vllm.config import VllmConfig

  from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.logger import init_logger
  from tpu_inference.models.common.model_loader import get_model
  from tpu_inference.runner import utils as runner_utils
  from tpu_inference.utils import device_array

+ logger = init_logger(__name__)
+

  class Eagle3Proposer:
  """A proposer for speculative decoding using the Eagle3 method.
@@ -51,9 +57,22 @@ class Eagle3Proposer:
  """Loads the draft model."""
  self.model_fn, self.compute_logits_fn, self.combine_hidden_states_fn, _, self.state, _, _ = get_model(
  self.vllm_config, self.rng_key, self.mesh, is_draft_model=True)
- if 'embed_tokens' in self.state.model:
- del self.state.model['embed_tokens']
- self.state.model.embed_tokens = target_model.model.embed
+
+ draft_embed_tokens = getattr(self.state.model, 'embed_tokens', None)
+ if draft_embed_tokens is None or ~jnp.any(
+ draft_embed_tokens.embedding):
+ logger.info(
+ "Draft model does not have embedding. Setting draft model's embed_tokens to target model's embed"
+ )
+ self.state.model.embed_tokens = target_model.model.embed
+ elif jnp.array_equal(draft_embed_tokens.embedding,
+ target_model.model.embed.embedding):
+ logger.info(
+ "Draft model's embed_tokens is identical to target model's embed. Sharing the embedding."
+ )
+ self.state.model.embed_tokens = target_model.model.embed
+ else:
+ logger.info("Draft model has its own embed_tokens.")

  @functools.partial(jax.jit, static_argnums=(0, ))
  def _prepare_input_ids(
@@ -111,6 +130,17 @@ class Eagle3Proposer:
  max_num_blocks_per_req)
  new_block_tables = jnp.where(expanded_exceeds_mask, -1, block_tables)

+ positions = lax.with_sharding_constraint(
+ positions, NamedSharding(self.mesh, PartitionSpec(None, )))
+ clamped_positions = lax.with_sharding_constraint(
+ clamped_positions, NamedSharding(self.mesh, PartitionSpec(None, )))
+ new_seq_lens = lax.with_sharding_constraint(
+ new_seq_lens, NamedSharding(self.mesh, PartitionSpec(None, )))
+ query_start_loc = lax.with_sharding_constraint(
+ query_start_loc, NamedSharding(self.mesh, PartitionSpec()))
+ new_block_tables = lax.with_sharding_constraint(
+ new_block_tables, NamedSharding(self.mesh, PartitionSpec(None, )))
+
  return positions, clamped_positions, new_seq_lens, query_start_loc, new_block_tables

  @functools.partial(jax.jit, static_argnums=(0, ))
@@ -122,6 +152,7 @@ class Eagle3Proposer:
  @functools.partial(jax.jit, static_argnums=(0, ))
  def _prepare_hidden_states_and_input_ids(
  self,
+ state: nnx.State,
  aux_hidden_states: tuple[jax.Array, ...],
  query_start_loc: jax.Array,
  target_token_ids: jax.Array,
@@ -130,7 +161,7 @@ class Eagle3Proposer:
  ) -> tuple[jax.Array, jax.Array, jax.Array]:
  target_hidden_states = jnp.concatenate(aux_hidden_states, axis=-1)
  target_hidden_states = self.combine_hidden_states_fn(
- self.state, target_hidden_states)
+ state, target_hidden_states)

  input_ids, last_token_indices = self._prepare_input_ids(
  query_start_loc, target_token_ids, next_token_ids, num_reqs)
@@ -177,8 +208,8 @@ class Eagle3Proposer:
  block_tables=device_array(
  self.mesh, block_tables))
  target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
- aux_hidden_states, attn_metadata.query_start_loc, input_ids,
- next_token_ids, num_reqs)
+ self.state, aux_hidden_states, attn_metadata.query_start_loc,
+ input_ids, next_token_ids, num_reqs)
  return target_hidden_states, input_ids, last_token_indices, attn_metadata

  # Host copies from the metadata prepared by the runner.
@@ -242,12 +273,13 @@ class Eagle3Proposer:

  attn_metadata = replace(attn_metadata, block_tables=block_tables)
  return self._filter_token_and_prepare_initial_inputs(
- token_indices, query_start_loc, seq_lens, input_ids,
+ self.state, token_indices, query_start_loc, seq_lens, input_ids,
  aux_hidden_states, attn_metadata, next_token_ids, num_reqs)

  @functools.partial(jax.jit, static_argnums=(0, ))
  def _filter_token_and_prepare_initial_inputs(
  self,
+ state: nnx.State,
  token_indices: jax.Array,
  query_start_loc: jax.Array,
  seq_lens: jax.Array,
@@ -275,35 +307,51 @@ class Eagle3Proposer:
  )

  target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
- [h[token_indices] for h in aux_hidden_states], query_start_loc,
- target_token_ids, next_token_ids, num_reqs)
+ state, [h[token_indices] for h in aux_hidden_states],
+ query_start_loc, target_token_ids, next_token_ids, num_reqs)

  return target_hidden_states, input_ids, last_token_indices, attn_metadata

  @functools.partial(jax.jit, static_argnums=(0, ))
  def _select_draft_token_ids(
  self,
+ state: nnx.State,
  hidden_states: jax.Array,
  last_token_indices: jax.Array,
  ) -> jax.Array:
  sample_hidden_states = hidden_states[last_token_indices]
- return self._get_draft_token_ids(sample_hidden_states)
+ sample_hidden_states = lax.with_sharding_constraint(
+ sample_hidden_states,
+ NamedSharding(self.mesh, PartitionSpec(None, None)))
+ return self._get_draft_token_ids(state, sample_hidden_states)

  @functools.partial(jax.jit, static_argnums=(0, ))
- def _get_draft_token_ids(self, hidden_states: jax.Array) -> jax.Array:
+ def _get_draft_token_ids(self, state: nnx.State,
+ hidden_states: jax.Array) -> jax.Array:
  lora_metadata = None
- logits = self.compute_logits_fn(self.state, hidden_states,
- lora_metadata)
- return jnp.argmax(logits, axis=-1)
+ logits = self.compute_logits_fn(state, hidden_states, lora_metadata)
+ draft_token_ids = jnp.argmax(logits, axis=-1)
+ return lax.with_sharding_constraint(
+ draft_token_ids, NamedSharding(self.mesh, PartitionSpec()))

  @functools.partial(jax.jit, static_argnums=(0, ))
  def _select_inputs_for_loop_speculation(
- self, positions: jax.Array, residual: jax.Array,
+ self, state: nnx.State, positions: jax.Array, residual: jax.Array,
  hidden_states: jax.Array,
  last_token_indices: jax.Array) -> tuple[jax.Array, jax.Array]:
- return positions[last_token_indices], residual[
- last_token_indices], self._select_draft_token_ids(
- hidden_states, last_token_indices)
+ positions = positions[last_token_indices]
+ residual = residual[last_token_indices]
+ draft_token_ids = self._select_draft_token_ids(state, hidden_states,
+ last_token_indices)
+
+ positions = lax.with_sharding_constraint(
+ positions, NamedSharding(self.mesh, PartitionSpec(None, )))
+ residual = lax.with_sharding_constraint(
+ residual, NamedSharding(self.mesh, PartitionSpec(None, None)))
+ draft_token_ids = lax.with_sharding_constraint(
+ draft_token_ids, NamedSharding(self.mesh, PartitionSpec()))
+
+ return positions, residual, draft_token_ids

  def propose(
  self,
@@ -330,11 +378,11 @@ class Eagle3Proposer:

  if self.num_speculative_tokens == 1:
  return kv_caches, self._select_draft_token_ids(
- hidden_states, last_token_indices)
+ self.state, hidden_states, last_token_indices)

  positions, hidden_states, draft_token_ids = self._select_inputs_for_loop_speculation(
- attn_metadata.input_positions, residual[0], hidden_states,
- last_token_indices)
+ self.state, attn_metadata.input_positions, residual[0],
+ hidden_states, last_token_indices)

  draft_token_ids_list = [draft_token_ids]

@@ -359,7 +407,8 @@ class Eagle3Proposer:
  attn_metadata,
  )
  hidden_states = residual[0]
- draft_token_ids = self._get_draft_token_ids(new_hidden_states)
+ draft_token_ids = self._get_draft_token_ids(
+ self.state, new_hidden_states)
  draft_token_ids_list.append(draft_token_ids)

  # [batch_size, num_speculative_tokens]
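Most of the eagle3.py changes thread the draft model's nnx.State through the jitted helpers and pin intermediate arrays with lax.with_sharding_constraint. The snippet below is a minimal standalone sketch (not package code, and the function and mesh axis names are made up) of what those constraints mean: PartitionSpec(None, ) and PartitionSpec() both leave the array replicated across the mesh rather than sharded.

    # Minimal sketch (not from the package): constraining an array to a
    # replicated layout on a 1-D mesh, as the new eagle3.py helpers do for
    # positions, seq_lens, block tables, and draft token ids.
    import jax
    import jax.numpy as jnp
    from jax import lax
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    mesh = Mesh(jax.devices(), axis_names=("model", ))  # hypothetical axis name

    @jax.jit
    def constrain(positions: jax.Array) -> jax.Array:
        # PartitionSpec(None, ) leaves the only dimension unsharded, so the
        # array ends up replicated on every device instead of being split.
        return lax.with_sharding_constraint(
            positions, NamedSharding(mesh, PartitionSpec(None, )))

    print(constrain(jnp.arange(8)).sharding)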
tpu_inference/utils.py CHANGED
@@ -8,11 +8,14 @@ from typing import Any, Callable, List, Tuple
  import jax
  import jax.numpy as jnp
  import numpy as np
+ import torch
  from jax._src import dtypes
  from jax._src import mesh as mesh_lib
  from jax._src import xla_bridge as xb
  from jax._src.lib import xla_client as xc
+ from jax._src.numpy.scalar_types import _ScalarMeta
  from jax.sharding import Mesh, NamedSharding, PartitionSpec
+ from torchax.ops.mappings import j2t_dtype, t2j_dtype
  from vllm import envs as vllm_envs
  from vllm import utils

@@ -23,21 +26,44 @@ GBYTES = 1024 * 1024 * 1024
  TPU_HEAD_SIZE_ALIGNMENT = 128
  TPU_SECOND_LAST_MINOR = 8

- # This is used to translate from a string name for a dtype
- # to formal jax.numpy DType. One use case for this is
- # converting the `--kv_cache_dtype` flag to a dtype.
- TPU_STR_DTYPE_TO_JAX_DTYPE = {
- "bfloat16": jnp.bfloat16,
- "fp8": jnp.float8_e4m3fn,
- "fp8_e4m3": jnp.float8_e4m3,
- "fp8_e5m2": jnp.float8_e5m2,
- "int8": jnp.int8,
+ # Map vllm dtype string that doesn't exactly match jax dtype string name.
+ _VLLM_DTYPE_STR_TO_JAX_DTYPE = {
+ "fp8": jnp.float8_e4m3fn.dtype,
+ "fp8_e4m3": jnp.float8_e4m3fn.dtype,
+ "fp8_e5m2": jnp.float8_e5m2.dtype,
  }

+
+ def to_jax_dtype(dtype: str | jnp.dtype | torch.dtype) -> jnp.dtype:
+ if isinstance(dtype, str):
+ if dict_dtype := _VLLM_DTYPE_STR_TO_JAX_DTYPE.get(dtype, None):
+ return dict_dtype
+ return jnp.dtype(dtype)
+ elif isinstance(dtype, torch.dtype):
+ return t2j_dtype(dtype)
+ elif isinstance(dtype, jnp.dtype):
+ return dtype
+ elif isinstance(dtype, _ScalarMeta):
+ return dtype.dtype
+ else:
+ raise ValueError(f"Argument is unsupported data type {type(dtype)}")
+
+
+ def to_torch_dtype(dtype: str | jnp.dtype | torch.dtype) -> torch.dtype:
+ # Use jax dtype as an intermediate dtype which we'll be used to convert it
+ # into torch dtype.
+ dtype = to_jax_dtype(dtype)
+ return j2t_dtype(dtype)
+
+
  _megacore = False
  logger = init_logger(__name__)


+ def align_to(unpadded_dim, pad_multiple):
+ return (unpadded_dim + pad_multiple - 1) // pad_multiple * pad_multiple
+
+
  def enable_megacore() -> None:
  global _megacore
  _megacore = True
@@ -164,7 +190,8 @@ def get_padded_num_heads(num_heads: int, sharding_size: int) -> int:


  def get_dtype_packing(dtype):
- bits = dtypes.bit_width(dtype)
+ bits = (dtypes.bit_width(dtype)
+ if hasattr(dtypes, "bit_width") else dtypes.itemsize_bits(dtype))
  return 32 // bits


@@ -249,11 +276,11 @@ def device_array(mesh: Mesh, *args, sharding=None, **kwargs) -> jax.Array:

  def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
  """
- A wrapper function of vllm.utils.get_hash_fn_by_name to support builtin
+ A wrapper function of vllm.utils.hashing.get_hash_fn_by_name to support builtin
  """
  if hash_fn_name == "builtin":
  return hash
- return utils.get_hash_fn_by_name(hash_fn_name)
+ return utils.hashing.get_hash_fn_by_name(hash_fn_name)


  def quantize_kv(key: jax.Array, value: jax.Array,
@@ -295,8 +322,8 @@ def get_jax_dtype_from_str_dtype(str_dtype: str) -> jnp.dtype:
  Returns:
  jnp.dtype: The JAX dtype.
  """
- str_dtype = str_dtype.lower().strip()
- return TPU_STR_DTYPE_TO_JAX_DTYPE.get(str_dtype)
+ # TODO(kyuyeunk): Replace all reference of this function into TpuDtype.
+ return to_jax_dtype(str_dtype)


  def time_function(func):
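The new to_jax_dtype/to_torch_dtype helpers replace the old TPU_STR_DTYPE_TO_JAX_DTYPE table and accept strings, JAX dtypes, or torch dtypes. Below is a small usage sketch; the commented results are what the alias table and the torchax conversions shown above imply, not output captured from the package.

    # Usage sketch for the new dtype helpers in tpu_inference/utils.py.
    # Expected values are inferred from the diff above, not verified output.
    import torch

    from tpu_inference.utils import align_to, to_jax_dtype, to_torch_dtype

    print(to_jax_dtype("fp8"))          # float8_e4m3fn, via the vLLM alias table
    print(to_jax_dtype("bfloat16"))     # bfloat16, falls through to jnp.dtype()
    print(to_jax_dtype(torch.float32))  # float32, converted with torchax t2j_dtype
    print(to_torch_dtype("fp8_e5m2"))   # torch.float8_e5m2, via j2t_dtype
    print(align_to(100, 8))             # 104: round 100 up to the next multiple of 8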
tpu_inference/worker/tpu_worker.py CHANGED
@@ -6,7 +6,6 @@ from dataclasses import dataclass, field
  from typing import Callable, Dict, Optional, Tuple

  import jax
- import jax.numpy as jnp
  import jaxlib
  import jaxtyping
  import vllm.envs as vllm_envs
@@ -19,7 +18,8 @@ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
  from vllm.lora.request import LoRARequest
  from vllm.tasks import SupportedTask
  from vllm.v1 import utils as vllm_utils
- from vllm.v1.core.kv_cache_utils import get_num_blocks, get_uniform_page_size
+ from vllm.v1.core.kv_cache_utils import (get_kv_cache_groups, get_num_blocks,
+ get_uniform_page_size)
  from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
  from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
  from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
@@ -32,17 +32,11 @@ from tpu_inference.layers.common.sharding import ShardingConfigManager
  from tpu_inference.logger import init_logger
  from tpu_inference.models.jax.jax_intermediate_tensor import \
  JaxIntermediateTensors
- from tpu_inference.runner.kv_cache import get_rpa_page_size_bytes
+ from tpu_inference.runner.kv_cache import get_attention_page_size_bytes
  from tpu_inference.runner.tpu_runner import TPUModelRunner

  logger = init_logger(__name__)

- _DTYPE: dict[str, jnp.dtype] = {
- "bfloat16": jnp.bfloat16,
- "float": jnp.float32,
- "float32": jnp.float32,
- }
-

  @dataclass
  class PPConfig:
@@ -77,21 +71,6 @@ class TPUWorker:
  ip: str = "localhost",
  prev_worker_ip: str = "localhost",
  ):
- # If we use vLLM's model implementation in PyTorch, we should set it
- # with torch version of the dtype.
- impl = envs.MODEL_IMPL_TYPE
- if impl != "vllm": # vllm-pytorch implementation does not need this conversion
-
- # NOTE(wenlong): because sometimes mm needs to use torch for preprocessing
- if not isinstance(vllm_config.model_config.dtype, str):
- logger.warning(
- "The model dtype is not properly set for JAX backend. "
- "Overwriting it to jnp.bfloat16")
- vllm_config.model_config.dtype = jnp.bfloat16
- else:
- vllm_config.model_config.dtype = _DTYPE.get(
- vllm_config.model_config.dtype, jnp.bfloat16)
-
  self.vllm_config = vllm_config
  self.model_config = vllm_config.model_config
  self.parallel_config = vllm_config.parallel_config
@@ -108,7 +87,7 @@ class TPUWorker:

  if self.model_config.trust_remote_code:
  # note: lazy import to avoid importing torch before initializing
- from vllm.utils import init_cached_hf_modules
+ from vllm.utils.import_utils import init_cached_hf_modules

  init_cached_hf_modules()

@@ -250,11 +229,20 @@ class TPUWorker:
  need_pp=self.parallel_config.pipeline_parallel_size > 1)

  ensure_kv_transfer_initialized(self.vllm_config)
- self.model_runner = TPUModelRunner(
- self.vllm_config, self.devices, self.rank, self.rank == 0,
- self.rank == self.pp_config.pp_world_size - 1)
+
+ is_first_rank = True
+ is_last_rank = True
+ if self.parallel_config.pipeline_parallel_size > 1:
+ is_first_rank = self.rank == 0
+ is_last_rank = self.rank == self.pp_config.pp_world_size - 1
+
+ self.model_runner = TPUModelRunner(self.vllm_config, self.devices,
+ self.rank, is_first_rank,
+ is_last_rank)
  logger.info(f"Init worker | "
  f"rank={self.rank} | "
+ f"is_first_rank={is_first_rank} | "
+ f"is_last_rank={is_last_rank} | "
  f"node_id={get_node_id()} | "
  f"is_driver_worker={self.is_driver_worker} | "
  f"hbm={utils.hbm_usage_gb(self.devices)}GiB")
@@ -357,7 +345,7 @@ class TPUWorker:
  if is_start:
  options = jax.profiler.ProfileOptions()
  # default: https://docs.jax.dev/en/latest/profiling.html#general-options
- options.python_tracer_level = os.getenv("PYTHON_TRACER_LEVEL", 0)
+ options.python_tracer_level = envs.PYTHON_TRACER_LEVEL
  options.host_tracer_level = os.getenv("HOST_TRACER_LEVEL", 1)
  jax.profiler.start_trace(self.profile_dir,
  profiler_options=options)
@@ -395,32 +383,37 @@ class TPUWorker:
  # responsible for this translation. When vLLM can be modified, this
  # method should be changed to return `dict[str, AbstractKVCacheSpec]`,
  # and the vLLM side should be updated to handle the translation.
- kv_cache_specs = self.model_runner.get_kv_cache_spec()
+ kv_cache_spec = self.model_runner.get_kv_cache_spec()

- if len(kv_cache_specs) == 0:
- return kv_cache_specs
+ if len(kv_cache_spec) == 0:
+ return kv_cache_spec

  # TODO(kyuyeunk): Instead of checking page_size_bytes here, introduce
  # feature that allows overriding page_size_bytes of KVCacheSpec.
- vllm_page_size_bytes = get_uniform_page_size(kv_cache_specs)
- rpa_page_size_bytes = get_rpa_page_size_bytes(self.model_runner.mesh,
- kv_cache_specs)
+ vllm_page_size_bytes = get_uniform_page_size(
+ list(kv_cache_spec.values()))
+ attention_page_size_bytes = get_attention_page_size_bytes(
+ self.model_runner.mesh, kv_cache_spec)

- if vllm_page_size_bytes != rpa_page_size_bytes:
+ if vllm_page_size_bytes != attention_page_size_bytes:
  logger.info(
- f"KV cache page size calculated by vLLM "
- f"({vllm_page_size_bytes} Bytes) does not match with actual "
- f"page size used by RPA kernel ({rpa_page_size_bytes} Bytes). "
- f"Recalculating number of KV blocks using actual page size.")
-
+ f"Page size calculated by vLLM ({vllm_page_size_bytes} Bytes) "
+ f"does not match with actual page size used by the kernel "
+ f"({attention_page_size_bytes} Bytes). Recalculating number of "
+ f"KV blocks using actual page size.")
+
+ kv_cache_groups = get_kv_cache_groups(self.vllm_config,
+ kv_cache_spec)
+ group_size = max(
+ len(group.layer_names) for group in kv_cache_groups)
  available_memory = self.determine_available_memory()
- num_blocks = get_num_blocks(self.vllm_config, len(kv_cache_specs),
- available_memory, rpa_page_size_bytes)
-
+ num_blocks = get_num_blocks(self.vllm_config, group_size,
+ available_memory,
+ attention_page_size_bytes)
  cache_config = self.vllm_config.cache_config
  cache_config.num_gpu_blocks_override = num_blocks

- return kv_cache_specs
+ return kv_cache_spec

  def initialize_from_config(
  self,
@@ -455,3 +448,8 @@ class TPUWorker:

  def shutdown(self) -> None:
  return
+
+ # Ray executor do not need handshake metadata
+ # as we pass the kv_parameters through proxy server
+ def get_kv_connector_handshake_metadata(self) -> None:
+ pass
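In get_kv_cache_spec, the worker now re-derives the KV block count from the attention kernel's actual page size and the largest KV cache group instead of vLLM's uniform page size. The sketch below is illustrative arithmetic only with made-up numbers; it is not the vLLM get_num_blocks implementation.

    # Illustrative arithmetic only (hypothetical numbers, not vLLM's
    # get_num_blocks): re-deriving the block count from the kernel's page size.
    available_memory = 8 * 1024**3            # bytes of HBM left for the KV cache
    attention_page_size_bytes = 2 * 1024**2   # page size reported by the attention kernel
    group_size = 2                            # layers in the largest KV cache group

    num_blocks = available_memory // (group_size * attention_page_size_bytes)
    print(num_blocks)  # 2048; written back via cache_config.num_gpu_blocks_override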
{tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: tpu_inference
- Version: 0.11.1.dev202511220812
+ Version: 0.12.0.dev20251213
  Author: tpu_inference Contributors
  Classifier: Development Status :: 3 - Alpha
  Classifier: Intended Audience :: Developers
@@ -14,7 +14,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Requires-Python: >=3.10
  Description-Content-Type: text/markdown
  License-File: LICENSE
- Requires-Dist: tpu-info==0.4.0
+ Requires-Dist: tpu-info==0.7.1
  Requires-Dist: yapf==0.43.0
  Requires-Dist: pytest
  Requires-Dist: pytest-mock
@@ -25,12 +25,13 @@ Requires-Dist: jax[tpu]==0.8.0
  Requires-Dist: jaxlib==0.8.0
  Requires-Dist: jaxtyping
  Requires-Dist: flax==0.11.1
- Requires-Dist: torchax==0.0.7
+ Requires-Dist: torchax==0.0.10
  Requires-Dist: qwix==0.1.1
  Requires-Dist: torchvision==0.24.0
  Requires-Dist: pathwaysutils
  Requires-Dist: parameterized
  Requires-Dist: numba==0.62.1
+ Requires-Dist: runai-model-streamer[gcs,s3]==0.15.0
  Dynamic: author
  Dynamic: classifier
  Dynamic: description
@@ -52,14 +53,12 @@ Dynamic: requires-python

  ---

- _Upcoming Events_ 🔥
-
- - Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundation.org/pytorch-conference/) in San Francisco!
- - Join us at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
- - Join us at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
-
  _Latest News_ 🔥

+ - [Pytorch Conference](https://pytorchconference.sched.com/event/27QCh/sponsored-session-everything-everywhere-all-at-once-vllm-hardware-optionality-with-spotify-and-google-brittany-rockwell-google-shireen-kheradpey-spotify) Learn how Spotify uses vLLM with both GPUs and TPUs to drive down costs and improve user experience.
+ - Check back soon for a recording of our session at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
+ - Check back soon for a recording of our session at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
+
  - [2025/10] [vLLM TPU: A New Unified Backend Supporting PyTorch and JAX on TPU](https://blog.vllm.ai/2025/10/16/vllm-tpu.html)

  <details>