tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (59)
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/kernels/mla_v1_test.py +129 -41
  3. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  4. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
  5. tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
  6. tests/lora/test_layers.py +4 -1
  7. tests/lora/test_lora_perf.py +53 -0
  8. tests/test_envs.py +110 -12
  9. tests/test_quantization.py +3 -0
  10. tests/test_utils.py +1 -2
  11. tpu_inference/distributed/tpu_connector.py +1 -1
  12. tpu_inference/envs.py +92 -8
  13. tpu_inference/executors/ray_distributed_executor.py +5 -1
  14. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  15. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  16. tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
  17. tpu_inference/kernels/mla/v1/kernel.py +98 -120
  18. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  19. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  20. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  21. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +82 -32
  22. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +146 -85
  23. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
  24. tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
  25. tpu_inference/layers/common/attention_interface.py +7 -1
  26. tpu_inference/layers/common/sharding.py +11 -7
  27. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
  28. tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
  29. tpu_inference/layers/vllm/fused_moe.py +170 -208
  30. tpu_inference/layers/vllm/linear_common.py +43 -21
  31. tpu_inference/layers/vllm/quantization/common.py +11 -6
  32. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
  33. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
  34. tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
  35. tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
  36. tpu_inference/models/common/model_loader.py +78 -22
  37. tpu_inference/models/jax/deepseek_v3.py +185 -64
  38. tpu_inference/models/jax/gpt_oss.py +3 -3
  39. tpu_inference/models/jax/llama_eagle3.py +4 -5
  40. tpu_inference/models/jax/qwen2_5_vl.py +161 -47
  41. tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
  42. tpu_inference/models/jax/utils/weight_utils.py +203 -155
  43. tpu_inference/models/vllm/vllm_model_wrapper.py +11 -5
  44. tpu_inference/platforms/tpu_platform.py +29 -48
  45. tpu_inference/runner/compilation_manager.py +112 -46
  46. tpu_inference/runner/kv_cache.py +40 -20
  47. tpu_inference/runner/kv_cache_manager.py +40 -31
  48. tpu_inference/runner/persistent_batch_manager.py +40 -2
  49. tpu_inference/runner/structured_decoding_manager.py +2 -3
  50. tpu_inference/runner/tpu_runner.py +94 -51
  51. tpu_inference/runner/utils.py +2 -2
  52. tpu_inference/spec_decode/jax/eagle3.py +71 -22
  53. tpu_inference/utils.py +41 -14
  54. tpu_inference/worker/tpu_worker.py +43 -45
  55. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +8 -9
  56. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +59 -58
  57. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
  58. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
  59. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
tpu_inference/runner/kv_cache_manager.py

@@ -7,8 +7,8 @@ import numpy as np
 import vllm.envs as envs
 from jax.sharding import NamedSharding, PartitionSpec
 from torchax.ops.mappings import t2j_dtype
-from vllm.attention import Attention
 from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.layer import Attention
 from vllm.config import get_layers_from_vllm_config
 from vllm.utils.math_utils import cdiv
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
@@ -39,20 +39,30 @@ class KVCacheManager:
         # means this layer will perform attention using the keys and values
         # from the KV cache of `shared_kv_cache_layers[layer_name]`.
         self.shared_kv_cache_layers: dict[str, str] = {}
+        self.use_mla = self.runner.model_config.use_mla
 
     def get_kv_cache_spec(self):
         # TODO(xiang): this hack tricks engine core to init successfully
         block_size = self.runner.cache_config.block_size
-        use_mla = self.runner.model_config.use_mla
         kv_cache_spec: dict[str, KVCacheSpec] = {}
 
         # If use pure jax (MODEL_IMPL_TYPE=flax_nnx), we don't register
         # attention into compilation config.
         # Use FullAttentionSpec for each layer
         # TODO(pooyam): Is it possible to merge the logic for vllm and non-vllm models?
+        model_config = self.runner.model_config
+        if self.use_mla:
+            # Individually pad the RopE and latents
+            qk_rope_head_dim = getattr(model_config.hf_text_config,
+                                       "qk_rope_head_dim", 0)
+            padded_kv_lora_rank = common_utils.align_to(
+                model_config.hf_text_config.kv_lora_rank, 128)
+            padded_qk_rope_head_dim = common_utils.align_to(
+                qk_rope_head_dim, 128)
+            mla_head_size = padded_kv_lora_rank + padded_qk_rope_head_dim
+
         if len(self.runner.vllm_config.compilation_config.
                static_forward_context) == 0:
-            model_config = self.runner.model_config
             parallel_config = self.runner.parallel_config
             # Pad num_kv_heads to multiple of TP size.
             num_kv_heads = common_utils.get_padded_num_heads(
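For context on the padding arithmetic introduced above: each MLA component is rounded up to a TPU-friendly multiple of 128 with `common_utils.align_to` before the two parts are summed into the cache head size. A minimal sketch, assuming `align_to` rounds up to the next multiple; the concrete dimensions below are DeepSeek-V3-style illustrations, not values taken from this diff:

def align_to(x: int, multiple: int) -> int:
    # Round x up to the next multiple (assumed behaviour of common_utils.align_to).
    return ((x + multiple - 1) // multiple) * multiple

kv_lora_rank = 512       # illustrative latent width
qk_rope_head_dim = 64    # illustrative RoPE sub-head width
mla_head_size = align_to(kv_lora_rank, 128) + align_to(qk_rope_head_dim, 128)
assert mla_head_size == 512 + 128  # latent and RoPE parts are padded separately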
@@ -61,11 +71,11 @@ class KVCacheManager:
             head_size = common_utils.get_padded_head_dim(
                 model_config.get_head_size())
             for i in range(model_config.get_num_layers(parallel_config)):
-                if use_mla:
+                if self.use_mla:
                     kv_cache_spec[f"layer.{i}"] = MLAAttentionSpec(
                         block_size=block_size,
-                        num_kv_heads=num_kv_heads,
-                        head_size=head_size,
+                        num_kv_heads=1,
+                        head_size=mla_head_size,
                         dtype=self.runner.kv_cache_dtype,
                         cache_dtype_str=self.runner.vllm_config.cache_config.
                         cache_dtype)
@@ -83,14 +93,13 @@ class KVCacheManager:
                 self.runner.mesh.shape["model"])
             head_size = common_utils.get_padded_head_dim(
                 hf_config.hidden_size // hf_config.num_attention_heads)
-
             # Eagle3 has only 1 layer
             for i in range(1):
-                if use_mla:
-                    kv_cache_spec[f"layer.{i}"] = MLAAttentionSpec(
+                if self.use_mla:
+                    kv_cache_spec[f"draft_layer.{i}"] = MLAAttentionSpec(
                         block_size=block_size,
-                        num_kv_heads=num_kv_heads,
-                        head_size=head_size,
+                        num_kv_heads=1,
+                        head_size=mla_head_size,
                         dtype=self.runner.kv_cache_dtype,
                         cache_dtype_str=self.runner.vllm_config.
                         cache_config.cache_dtype)
@@ -104,6 +113,7 @@ class KVCacheManager:
         # Else propagate attention modules from compilation config.
         layers = get_layers_from_vllm_config(self.runner.vllm_config,
                                              Attention)
+        logger.warning(f"Compilation num_layers = {len(layers.items())}")
         for layer_name, attn_module in layers.items():
             if (kv_tgt_layer :=
                     attn_module.kv_sharing_target_layer_name) is not None:
@@ -127,11 +137,11 @@ class KVCacheManager:
                         attn_module.head_size),
                     dtype=self.runner.kv_cache_dtype,
                     sliding_window=attn_module.sliding_window)
-            elif use_mla:
-                kv_cache_spec[f"layer.{i}"] = MLAAttentionSpec(
+            elif self.use_mla:
+                kv_cache_spec[layer_name] = MLAAttentionSpec(
                     block_size=block_size,
-                    num_kv_heads=attn_module.num_kv_heads,
-                    head_size=attn_module.head_size,
+                    num_kv_heads=1,
+                    head_size=mla_head_size,
                     dtype=self.runner.kv_cache_dtype,
                     cache_dtype_str=self.runner.vllm_config.
                     cache_config.cache_dtype)
@@ -198,14 +208,20 @@ class KVCacheManager:
             # num_blocks must be a multiple of dp_size
             num_blocks = (num_blocks // dp_size) * dp_size
             # NOTE: we'll multiply the num_kv_heads by 2 in the function
+            if self.use_mla:
+                head_size = self.runner.model_config.hf_config.kv_lora_rank + \
+                    self.runner.model_config.hf_config.qk_rope_head_dim
+            else:
+                head_size = representative_spec.head_size
             kv_cache = create_kv_caches(
                 num_blocks=num_blocks,
                 block_size=representative_spec.block_size,
                 num_kv_heads=representative_spec.num_kv_heads,
-                head_size=representative_spec.head_size,
+                head_size=head_size,
                 mesh=self.runner.mesh,
                 layer_names=[f'kv_cache_tensor.{i}'],
                 cache_dtype=t2j_dtype(representative_spec.dtype),
+                use_mla=self.use_mla,
             )[0]
             kv_caches.append(kv_cache)
             num_blocks_list.append(num_blocks)
@@ -289,13 +305,8 @@ class KVCacheManager:
 
         def _update_layer(cache, slices):
             """The function to apply to each layer's cache and slices."""
-            reshaped_slices = slices.reshape(-1, 1, block_size,
-                                             *slices.shape[1:])
-            for (i, block_idx) in enumerate(block_numbers):
-                cache = jax.lax.dynamic_update_slice_in_dim(cache,
-                                                            reshaped_slices[i],
-                                                            block_idx,
-                                                            axis=0)
+            reshaped_slices = slices.reshape(-1, block_size, *slices.shape[1:])
+            cache.at[block_numbers].set(reshaped_slices)
             return cache
 
         return jax.tree.map(_update_layer, kv_caches, kv_cache_slices)
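The rewritten `_update_layer` above replaces a per-block `jax.lax.dynamic_update_slice_in_dim` loop with a single batched indexed update. A minimal sketch of that idiom with illustrative shapes; note that `.at[...].set(...)` is functional, i.e. it returns a new array rather than mutating `cache`, so the result has to be rebound for the update to take effect:

import jax.numpy as jnp

# (num_blocks, block_size, num_kv_heads, head_dim) -- shapes are illustrative.
cache = jnp.zeros((8, 4, 2, 16))
slices = jnp.ones((2 * 4, 2, 16))          # two blocks' worth of token slices
block_numbers = jnp.array([3, 5])          # destination block indices
reshaped = slices.reshape(-1, 4, *slices.shape[1:])
cache = cache.at[block_numbers].set(reshaped)   # rebind: .at[].set() is out-of-place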
@@ -348,16 +359,12 @@ class KVCacheManager:
         """
         if block_ids == list(range(block_ids[0],
                                    block_ids[0] + len(block_ids))):
-            with runner_utils.LatencyTracker(
-                    "BatchedGatherKVSlices-for-blocks"):
-                batched_kv_cache_per_layer = self._jitted_gather_continuous_kv_cache(
-                    self.runner.kv_caches, block_ids[0], len(block_ids))
+            batched_kv_cache_per_layer = self._jitted_gather_continuous_kv_cache(
+                self.runner.kv_caches, block_ids[0], len(block_ids))
 
         else:
-            with runner_utils.LatencyTracker(
-                    "BatchedGatherKVSlices-for-blocks"):
-                batched_kv_cache_per_layer = self._jitted_gather_kv_cache(
-                    self.runner.kv_caches, jnp.array(block_ids))
+            batched_kv_cache_per_layer = self._jitted_gather_kv_cache(
+                self.runner.kv_caches, jnp.array(block_ids))
         return batched_kv_cache_per_layer
 
     def transfer_kv_cache(self,
@@ -446,6 +453,7 @@ class KVCacheManager:
                     kv_cache_slices,
                     start_block,
                 )
+            jax.block_until_ready(self.runner.kv_caches)
         else:
             with runner_utils.LatencyTracker(
                     f"JittedInsertKVCache-b{len(block_numbers)}"):
@@ -457,6 +465,7 @@ class KVCacheManager:
                     kv_cache_slices,
                     jnp.array(block_numbers),
                 )
+            jax.block_until_ready(self.runner.kv_caches)
 
         logger.debug(
             f"Updated kv cache entries cnt={len(self.runner.kv_caches)}")
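`jax.block_until_ready` is added after both insertion paths so the host waits for the asynchronously dispatched cache updates to finish before proceeding. A small sketch of that call on a pytree of arrays (shapes illustrative):

import jax
import jax.numpy as jnp

kv_caches = [jnp.zeros((16, 4, 2, 128)) for _ in range(2)]
kv_caches = [c + 1 for c in kv_caches]   # dispatched asynchronously on the device
jax.block_until_ready(kv_caches)         # blocks until every leaf has materialized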
tpu_inference/runner/persistent_batch_manager.py

@@ -14,12 +14,13 @@ class PersistentBatchManager:
     def __init__(self, requests: Dict[str, CachedRequestState],
                  input_batch: InputBatch, encoder_cache: Dict[str,
                                                               'jax.Array'],
-                 uses_mrope: bool, model_config):
+                 uses_mrope: bool, model_config, is_last_rank: bool):
         self.requests = requests
         self.input_batch = input_batch
         self.encoder_cache = encoder_cache
         self.uses_mrope = uses_mrope
         self.model_config = model_config
+        self.is_last_rank = is_last_rank
 
     def _reorder_batch(self, scheduler_output: "VllmSchedulerOutput") -> int:
         """ Reorder the sheduled requests to RPA kernel friendly distribution
@@ -179,9 +180,35 @@ class PersistentBatchManager:
             num_computed_tokens = req_data.num_computed_tokens[i]
             new_block_ids = req_data.new_block_ids[i]
             resumed_from_preemption = req_data.resumed_from_preemption[i]
+            num_output_tokens = req_data.num_output_tokens[i]
 
             # Update the cached states.
             req_state.num_computed_tokens = num_computed_tokens
+            req_index = self.input_batch.req_id_to_index.get(req_id)
+
+            if not self.is_last_rank:
+                # When using PP, the scheduler sends the sampled tokens back,
+                # because there's no direct communication between the first-
+                # stage worker and the last-stage worker.
+                new_token_ids = req_data.new_token_ids[i]
+                # Add the sampled token(s) from the previous step (if any).
+                # This doesn't include "unverified" tokens like spec tokens.
+                num_new_tokens = (num_computed_tokens + len(new_token_ids) -
+                                  req_state.num_tokens)
+                if num_new_tokens == 1:
+                    req_state.output_token_ids.append(new_token_ids[-1])
+                elif num_new_tokens > 0:
+                    req_state.output_token_ids.extend(
+                        new_token_ids[-num_new_tokens:])
+            elif num_output_tokens < len(req_state.output_token_ids):
+                del req_state.output_token_ids[num_output_tokens:]
+                if req_index is not None:
+                    end_idx = (self.input_batch.num_prompt_tokens[req_index] +
+                               num_output_tokens)
+                    self.input_batch.num_tokens[req_index] = end_idx
+                    self.input_batch.num_tokens_no_spec[req_index] = end_idx
+
+            # Update the block IDs.
             if not resumed_from_preemption:
                 if new_block_ids is not None:
                     # Append the new blocks to the existing block IDs.
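The pipeline-parallel branch above reconstructs how many freshly sampled tokens to append on non-last ranks from the scheduler's computed-token count and the tokens already cached for the request. A worked example with illustrative numbers:

# A request with 10 prompt tokens and 2 previously appended output tokens, so
# req_state.num_tokens == 12; the scheduler now reports 12 computed tokens and
# ships 1 newly sampled token back to the non-last ranks.
num_computed_tokens = 12
new_token_ids = [42]
num_tokens = 12
num_new_tokens = num_computed_tokens + len(new_token_ids) - num_tokens
assert num_new_tokens == 1   # exactly one token gets appended to output_token_ids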
@@ -194,7 +221,6 @@
                 # Replace the existing block IDs with the new ones.
                 req_state.block_ids = new_block_ids
 
-            req_index = self.input_batch.req_id_to_index.get(req_id)
             if req_index is None:
                 # The request is not in the persistent batch.
                 # The request was either preempted and resumed later, or was not
@@ -209,6 +235,18 @@
                 self.input_batch.block_table.append_row(
                     new_block_ids, req_index)
 
+            # For the last rank, we don't need to update the token_ids_cpu
+            # because the sampled tokens are already cached.
+            if not self.is_last_rank:
+                start_token_index = num_computed_tokens
+                end_token_index = num_computed_tokens + len(new_token_ids)
+                self.input_batch.token_ids_cpu[
+                    req_index,
+                    start_token_index:end_token_index] = new_token_ids
+                self.input_batch.num_tokens_no_spec[
+                    req_index] = end_token_index
+                self.input_batch.num_tokens[req_index] = end_token_index
+
             # Add spec_token_ids to token_ids_cpu.
             spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
                 req_id, ())
tpu_inference/runner/structured_decoding_manager.py

@@ -61,11 +61,10 @@ class StructuredDecodingManager:
         self.runner.require_structured_out_cpu.fill(0)
 
         sorted_struct_requests = sorted(
-            grammar_output.structured_output_request_ids.items(),
-            key=lambda item: item[1])
+            grammar_output.structured_output_request_ids)
 
         cumulative_mask_idx = 0
-        for req_id, _ in sorted_struct_requests:
+        for req_id in sorted_struct_requests:
             if req_id not in self.runner.input_batch.req_id_to_index:
                 continue
             batch_index = self.runner.input_batch.req_id_to_index[req_id]
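The simplification above drops the `.items()` iteration and the sort-by-value in favour of sorting the request ids directly. If `structured_output_request_ids` is still a mapping, sorting it yields its keys, ordered lexicographically by request id, as in this small sketch (the concrete type of the field in the new scheduler output is not shown in this diff):

structured_output_request_ids = {"req-b": 1, "req-a": 0}
assert sorted(structured_output_request_ids) == ["req-a", "req-b"]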
tpu_inference/runner/tpu_runner.py

@@ -1,6 +1,6 @@
 import copy
 import functools
-import os
+import logging
 import random
 from contextlib import nullcontext
 from dataclasses import dataclass
@@ -10,17 +10,15 @@ import jax
 import jax.numpy as jnp
 import jaxtyping
 import numpy as np
-import torch
-import vllm.envs as envs
+import vllm.envs as vllm_envs
 from flax import nnx
 from jax.experimental import mesh_utils
 from jax.sharding import NamedSharding, PartitionSpec
-from torchax.ops.mappings import j2t_dtype
 from vllm.config import VllmConfig
+from vllm.distributed import get_pp_group
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.forward_context import set_forward_context
-from vllm.sequence import IntermediateTensors
 from vllm.tasks import SupportedTask
 from vllm.utils.math_utils import cdiv
 from vllm.v1.core.sched.output import GrammarOutput
@@ -35,6 +33,7 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import \
     KVConnectorModelRunnerMixin
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 
+import tpu_inference.envs as envs
 from tpu_inference import utils as common_utils
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.common.sharding import (MESH_AXIS_NAMES,
@@ -48,6 +47,8 @@ from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata
 from tpu_inference.logger import init_logger
 from tpu_inference.models.common.model_loader import get_model
+from tpu_inference.models.jax.jax_intermediate_tensor import \
+    JaxIntermediateTensors
 from tpu_inference.models.jax.utils.weight_utils import (
     shard_put, transfer_state_with_mappings)
 from tpu_inference.runner import utils as runner_utils
@@ -64,10 +65,12 @@ from tpu_inference.runner.structured_decoding_manager import \
     StructuredDecodingManager
 from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
 from tpu_inference.utils import (device_array, make_optimized_mesh,
-                                 time_function)
+                                 time_function, to_jax_dtype, to_torch_dtype)
 
 logger = init_logger(__name__)
 
+logging.getLogger("torchax.tensor").setLevel(logging.ERROR)
+
 INVALID_TOKEN_ID = -1
 # Smallest output size
 MIN_NUM_SEQS = 8
@@ -78,17 +81,6 @@ DUMMY_METADATA = AttentionMetadata(
     request_distribution=[0, 0, 0],
 )
 
-TPU_STR_DTYPE_TO_TORCH_DTYPE = {
-    "half": torch.half,
-    "bfloat16": torch.bfloat16,
-    "float": torch.float,
-    "fp8": torch.float8_e4m3fn,
-    "fp8_e4m3": torch.float8_e4m3fn,
-    "fp8_e5m2": torch.float8_e5m2,
-    "int8": torch.int8,
-    "uint8": torch.uint8,
-}
-
 
 class AsyncTPUModelRunnerOutput(AsyncModelRunnerOutput):
     """Holds asynchronous model output specifically from a TPU runner.
@@ -243,6 +235,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         self.maybe_forbid_compile = runner_utils.ForbidCompile(
         ) if envs.VLLM_XLA_CHECK_RECOMPILATION else nullcontext()
         self.dp_size = self.vllm_config.sharding_config.total_dp_size
+        self.rank = rank
+        self.is_first_rank = is_first_rank
+        self.is_last_rank = is_last_rank
 
         self._init_random()
         self._init_mesh()
@@ -253,31 +248,21 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
         # Delegate functions to specific manager classes.
         self.compilation_manager = CompilationManager(self)
-        self.speculative_decoding_manager = SpeculativeDecodingManager(self)
-        self.structured_decoding_manager = StructuredDecodingManager(self)
+        if self.is_last_rank:
+            self.speculative_decoding_manager = SpeculativeDecodingManager(
+                self)
+            self.structured_decoding_manager = StructuredDecodingManager(self)
         self.kv_cache_manager = KVCacheManager(self)
         self.mm_manager = MultiModalManager(self)
         self.persistent_batch_manager = PersistentBatchManager(
             self.requests, self.input_batch, self.encoder_cache,
-            self.uses_mrope, self.model_config)
+            self.uses_mrope, self.model_config, self.is_last_rank)
         self.lora_utils = LoraUtils(self)
 
-        cache_config = self.cache_config
-        if cache_config.cache_dtype == "auto":
-            model_dtype = self.dtype
-            if isinstance(model_dtype, str):
-                self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
-            elif isinstance(getattr(model_dtype, 'dtype', None), jnp.dtype):
-                self.kv_cache_dtype = j2t_dtype(model_dtype.dtype)
-            elif isinstance(model_dtype, torch.dtype):
-                self.kv_cache_dtype = model_dtype
-            else:
-                raise ValueError(
-                    "KV cache is unsupported for model_dtype of %s",
-                    model_dtype)
-        else:
-            self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[
-                cache_config.cache_dtype]
+        cache_dtype = self.cache_config.cache_dtype
+        if cache_dtype == "auto":
+            cache_dtype = self.dtype
+        self.kv_cache_dtype = to_torch_dtype(cache_dtype)
 
         self._pre_async_results: AsyncPreResults | None = None
         self._substitute_placeholder_token_fn = _substitute_placeholder_token
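The dtype-resolution block collapses into a single call to `to_torch_dtype` from `tpu_inference.utils` (imported above together with `to_jax_dtype`). A hedged sketch of what such a normalization helper could look like, reusing the string table this diff removes; the actual implementation lives in `tpu_inference/utils.py` and may handle additional cases such as JAX dtypes:

import torch

_STR_TO_TORCH = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
    "fp8": torch.float8_e4m3fn,
    "fp8_e4m3": torch.float8_e4m3fn,
    "fp8_e5m2": torch.float8_e5m2,
    "int8": torch.int8,
    "uint8": torch.uint8,
}

def to_torch_dtype(dtype) -> torch.dtype:
    # Illustrative only: pass torch dtypes through, map known strings, reject the rest.
    if isinstance(dtype, torch.dtype):
        return dtype
    if isinstance(dtype, str):
        return _STR_TO_TORCH[dtype]
    raise ValueError(f"Unsupported KV cache dtype: {dtype!r}")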
@@ -291,7 +276,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         self.rng_key = jax.random.key(self.model_config.seed)
 
     def _init_mesh(self) -> None:
-        if os.getenv("NEW_MODEL_DESIGN", False):
+        if envs.NEW_MODEL_DESIGN:
             self.mesh = self._create_new_model_mesh()
         else:
             # NOTE(wenxindongwork): The new MoE kernel expects a 2D mesh, so we need
@@ -302,7 +287,7 @@
         logger.info(f"Init mesh | mesh={self.mesh}")
 
     def _create_new_model_mesh(self) -> jax.sharding.Mesh:
-        num_slices = int(os.environ.get('NUM_SLICES', 1))
+        num_slices = envs.NUM_SLICES
 
         logger.info(f"Creating new model mesh | devices={len(self.devices)}, "
                     f"num_slices={num_slices}")
@@ -371,7 +356,7 @@
             devices=self.devices)
 
     def _init_phased_profiling(self) -> None:
-        self.phased_profiling_dir = os.getenv("PHASED_PROFILING_DIR", "")
+        self.phased_profiling_dir = envs.PHASED_PROFILING_DIR
         self.phase_based_profiler = None
         if self.phased_profiling_dir:
             self.phase_based_profiler = runner_utils.PhasedBasedProfiler(
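Several raw `os.getenv` / `os.environ` lookups above are replaced by attributes of the new `tpu_inference.envs` module (`NEW_MODEL_DESIGN`, `NUM_SLICES`, `PHASED_PROFILING_DIR`). A hedged sketch of the typed accessor pattern this implies; the helper names and defaults below are illustrative, not taken from `tpu_inference/envs.py`:

import os

def _bool_env(name: str, default: bool = False) -> bool:
    # Treat "1"/"true"/"yes" (any case) as enabled.
    return os.environ.get(name, str(default)).lower() in ("1", "true", "yes")

def _int_env(name: str, default: int) -> int:
    return int(os.environ.get(name, default))

NEW_MODEL_DESIGN = _bool_env("NEW_MODEL_DESIGN")
NUM_SLICES = _int_env("NUM_SLICES", 1)
PHASED_PROFILING_DIR = os.environ.get("PHASED_PROFILING_DIR", "")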
@@ -413,7 +398,7 @@
             min_token_size=max(16, self.dp_size),
             max_token_size=scheduler_config.max_num_batched_tokens *
             self.dp_size,
-            padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
+            padding_gap=vllm_envs.VLLM_TPU_BUCKET_PADDING_GAP)
         self.num_tokens_paddings_per_dp = [
             padding // self.dp_size for padding in self.num_tokens_paddings
         ]
@@ -555,12 +540,12 @@
     def execute_model(
         self,
         scheduler_output: "VllmSchedulerOutput",
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> ModelRunnerOutput | None:
+        intermediate_tensors: Optional[JaxIntermediateTensors] = None,
+    ) -> ModelRunnerOutput | JaxIntermediateTensors | None:
         if self.execute_model_state is not None:
             raise RuntimeError("State error: sample_tokens() must be called "
                                "after execute_model() returns None.")
-        _, output = self._execute_model(scheduler_output)
+        _, output = self._execute_model(scheduler_output, intermediate_tensors)
         return output
 
     def sample_tokens(
@@ -686,7 +671,9 @@
     def _execute_model(
         self,
         scheduler_output: "VllmSchedulerOutput",
-    ) -> tuple[AttentionMetadata, ModelRunnerOutput | None]:
+        intermediate_tensors: Optional[JaxIntermediateTensors] = None,
+    ) -> tuple[AttentionMetadata, JaxIntermediateTensors | ModelRunnerOutput
+               | None]:
         self.persistent_batch_manager.update_states(
             scheduler_output, self.get_mrope_input_positions_fn)
         if not scheduler_output.total_num_scheduled_tokens:
@@ -764,7 +751,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 scheduler_output) as kv_connector_output:
             # NOTE(Wenlong): It takes both `input_ids` and `inputs_embeds`,
             # but one of them would be `None`
-
            (self.kv_caches, hidden_states,
             aux_hidden_states) = self.model_fn(
                 self.state,
@@ -775,8 +761,14 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 input_positions,
                 tuple(self.layer_name_to_kvcache_index.items()),
                 lora_metadata,
+                intermediate_tensors,
+                self.is_first_rank,
+                self.is_last_rank,
             )
-
+        if not get_pp_group().is_last_rank:
+            assert isinstance(hidden_states, JaxIntermediateTensors)
+            hidden_states.kv_connector_output = kv_connector_output
+            return attn_metadata, hidden_states
         hidden_states = self._select_from_array_fn(hidden_states,
                                                    logits_indices)
         logits = self.compute_logits_fn(
@@ -822,18 +814,31 @@
 
         tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
             self.mesh, self.input_batch, padded_num_reqs, sharding=sharding)
+
+        # TODO(pooyam): Should we move this to `_prepare_inputs`?
+        if tpu_sampling_metadata.do_sampling:
+            self.rng_params_for_sampling, step_rng = jax.random.split(
+                self.rng_params_for_sampling)
+        else:
+            step_rng = self.rng_params_for_sampling
+
         if spec_decode_metadata is None:
             next_tokens = sample(
-                self.rng_params_for_sampling,
+                step_rng,
                 self.mesh,
                 logits,
                 tpu_sampling_metadata,
             )
         else:
+            if tpu_sampling_metadata.do_sampling:
+                bonus_rng, rejection_rng = jax.random.split(step_rng)
+            else:
+                bonus_rng = step_rng
+                rejection_rng = step_rng
             bonus_logits = self._select_from_array_fn(
                 logits, spec_decode_metadata.bonus_logits_indices)
             bonus_token_ids = sample(
-                self.rng_params_for_sampling,
+                bonus_rng,
                 self.mesh,
                 bonus_logits,
                 tpu_sampling_metadata,
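The sampling path above now advances the persistent RNG key only when sampling actually happens, and derives separate keys for bonus-token sampling and rejection sampling in the speculative-decoding branch. A minimal sketch of that key-splitting discipline in JAX:

import jax

rng = jax.random.key(0)                        # persistent per-runner key
rng, step_rng = jax.random.split(rng)          # advance once per sampling step
bonus_rng, rejection_rng = jax.random.split(step_rng)   # independent keys within the step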
@@ -847,7 +852,7 @@
                 target_logits=target_logits,
                 bonus_token_ids=bonus_token_ids,
                 sampling_metadata=tpu_sampling_metadata,
-                key=self.rng_params_for_sampling,
+                key=rejection_rng,
             )
 
         if tpu_sampling_metadata.logprobs:
@@ -1332,7 +1337,14 @@
         _request_distribution = []
         for dp_rank in range(dp_size):
             _num_reqs = num_req_per_dp_rank[dp_rank]
-            _request_distribution.append([0, 0, _num_reqs])
+            # The batch has been reordered by _reorder_batch so decode requests come first
+            # Count decode requests (those with num_scheduled_tokens == 1) in this DP rank
+            num_decode_in_dp_rank = 0
+            for req_id in req_ids_dp[dp_rank]:
+                if scheduler_output.num_scheduled_tokens[req_id] == 1:
+                    num_decode_in_dp_rank += 1
+            _request_distribution.append(
+                [num_decode_in_dp_rank, num_decode_in_dp_rank, _num_reqs])
         request_distribution = np.array(_request_distribution).ravel()
 
         use_spec_decode = len(
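The request distribution above now records how many decode requests (exactly one scheduled token) lead each DP rank's reordered batch, instead of a hard-coded zero. A worked example with illustrative request ids:

num_scheduled_tokens = {"r0": 1, "r1": 1, "r2": 17}   # two decodes, one prefill
req_ids_dp = {0: ["r0", "r1", "r2"]}                  # a single DP rank
num_decode = sum(1 for r in req_ids_dp[0] if num_scheduled_tokens[r] == 1)
_request_distribution = [[num_decode, num_decode, len(req_ids_dp[0])]]
assert _request_distribution == [[2, 2, 3]]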
@@ -1391,7 +1403,7 @@
                 block_tables[
                     req_offset:req_offset + _num_reqs, :self.
                     max_num_blocks_per_req] = self.input_batch.block_table[
-                        0].get_cpu_tensor()[req_indices_dp[dp_rank]]
+                        kv_cache_gid].get_cpu_tensor()[req_indices_dp[dp_rank]]
             # Convert block_tables to 1D on cpu.
             block_tables = block_tables.reshape(-1)
             block_tables = device_array(
@@ -1706,3 +1718,34 @@
                 mappings=mappings,
                 transpose_keys=transpose_keys,
                 shard=shard)
+
+    def get_intermediate_tensor_spec(self, num_tokens: int):
+        jax_dtype = to_jax_dtype(self.dtype)
+        num_padded_tokens = runner_utils.get_padded_token_len(
+            self.num_tokens_paddings, num_tokens)
+        sharding = NamedSharding(self.mesh, PartitionSpec())
+        hidden_size = self.model_config.get_hidden_size()
+        spec = jax.ShapeDtypeStruct(shape=(num_padded_tokens, hidden_size),
+                                    dtype=jax_dtype,
+                                    sharding=sharding)
+        tensor_spec = {"hidden_states": spec, "residual": spec}
+        return tensor_spec
+
+    def get_uuid_for_jax_transfer(self,
+                                  scheduler_output: "VllmSchedulerOutput",
+                                  rank: int, step: int) -> int:
+        '''
+        Get a uuid for jax.transfer, here we use the hash of
+        scheduler_output + counter_step + sender's rank
+        '''
+        scheduler_output_str = ""
+        if not scheduler_output.num_scheduled_tokens:
+            scheduler_output_str = "empty_batch"
+        else:
+            scheduler_output_str = str(
+                sorted(scheduler_output.num_scheduled_tokens.items()))
+        unique_str = f'{scheduler_output_str} {step} {rank}'
+        import hashlib
+        hasher = hashlib.sha1()
+        hasher.update(unique_str.encode('utf-8'))
+        return int.from_bytes(hasher.digest()[:8], 'big')
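`get_uuid_for_jax_transfer` derives a deterministic 64-bit id from a summary of the scheduler output, the step counter and the sender's rank, so both pipeline stages can agree on a transfer id without extra communication. A small sketch of the same derivation with illustrative inputs:

import hashlib

unique_str = "[('req-0', 1), ('req-1', 17)] 42 0"     # (scheduler summary, step, rank)
digest = hashlib.sha1(unique_str.encode("utf-8")).digest()
transfer_uuid = int.from_bytes(digest[:8], "big")      # first 8 bytes -> 64-bit int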
tpu_inference/runner/utils.py

@@ -15,6 +15,7 @@ import jax
 from jax._src.interpreters import pxla
 from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
 
+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 from tpu_inference.runner.input_batch import InputBatch
 
@@ -306,8 +307,7 @@ class PhasedBasedProfiler:
             InferencePhase.BALANCED: False
         }
         self.default_profiling_options = jax.profiler.ProfileOptions()
-        self.default_profiling_options.python_tracer_level = os.getenv(
-            "PYTHON_TRACER_LEVEL", 0)
+        self.default_profiling_options.python_tracer_level = envs.PYTHON_TRACER_LEVEL
 
         self.current_phase: str = ""