tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (76)
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/kernels/mla_v1_test.py +129 -41
  3. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  4. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
  5. tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
  6. tests/lora/test_layers.py +4 -7
  7. tests/lora/test_lora_perf.py +53 -0
  8. tests/lora/utils.py +0 -8
  9. tests/test_envs.py +110 -12
  10. tests/test_quantization.py +3 -0
  11. tests/test_utils.py +1 -2
  12. tpu_inference/__init__.py +22 -3
  13. tpu_inference/core/disagg_utils.py +6 -8
  14. tpu_inference/distributed/tpu_connector.py +3 -4
  15. tpu_inference/distributed/utils.py +3 -2
  16. tpu_inference/envs.py +93 -9
  17. tpu_inference/executors/ray_distributed_executor.py +9 -2
  18. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  19. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  20. tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
  21. tpu_inference/kernels/mla/v1/kernel.py +98 -120
  22. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  23. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  24. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  25. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +140 -67
  26. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +204 -120
  27. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
  28. tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
  29. tpu_inference/layers/common/attention_interface.py +7 -1
  30. tpu_inference/layers/common/sharding.py +11 -7
  31. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
  32. tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
  33. tpu_inference/layers/vllm/fused_moe.py +170 -208
  34. tpu_inference/layers/vllm/linear_common.py +43 -21
  35. tpu_inference/layers/vllm/quantization/common.py +11 -6
  36. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
  37. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
  38. tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
  39. tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
  40. tpu_inference/layers/vllm/sharding.py +2 -2
  41. tpu_inference/lora/torch_punica_tpu.py +1 -2
  42. tpu_inference/models/common/model_loader.py +84 -28
  43. tpu_inference/models/jax/deepseek_v3.py +185 -64
  44. tpu_inference/models/jax/gpt_oss.py +3 -3
  45. tpu_inference/models/jax/llama3.py +2 -1
  46. tpu_inference/models/jax/llama_eagle3.py +8 -5
  47. tpu_inference/models/jax/llama_guard_4.py +361 -0
  48. tpu_inference/models/jax/qwen2.py +2 -1
  49. tpu_inference/models/jax/qwen2_5_vl.py +163 -48
  50. tpu_inference/models/jax/qwen3.py +2 -1
  51. tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
  52. tpu_inference/models/jax/utils/weight_utils.py +205 -144
  53. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -8
  54. tpu_inference/platforms/tpu_platform.py +34 -50
  55. tpu_inference/runner/compilation_manager.py +144 -60
  56. tpu_inference/runner/kv_cache.py +40 -20
  57. tpu_inference/runner/kv_cache_manager.py +48 -33
  58. tpu_inference/runner/persistent_batch_manager.py +40 -2
  59. tpu_inference/runner/structured_decoding_manager.py +2 -3
  60. tpu_inference/runner/tpu_runner.py +280 -149
  61. tpu_inference/runner/utils.py +2 -2
  62. tpu_inference/spec_decode/jax/eagle3.py +71 -21
  63. tpu_inference/tpu_info.py +4 -3
  64. tpu_inference/utils.py +46 -18
  65. tpu_inference/worker/tpu_worker.py +197 -63
  66. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +9 -10
  67. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +70 -74
  68. tpu_inference/mock/__init__.py +0 -0
  69. tpu_inference/mock/vllm_config_utils.py +0 -28
  70. tpu_inference/mock/vllm_envs.py +0 -1219
  71. tpu_inference/mock/vllm_logger.py +0 -212
  72. tpu_inference/mock/vllm_logging_utils.py +0 -15
  73. tpu_inference/models/jax/phi3.py +0 -376
  74. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
  75. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
  76. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
  import copy
  import functools
- import os
+ import logging
  import random
  from contextlib import nullcontext
  from dataclasses import dataclass
@@ -10,17 +10,15 @@ import jax
  import jax.numpy as jnp
  import jaxtyping
  import numpy as np
- import torch
- import vllm.envs as envs
+ import vllm.envs as vllm_envs
  from flax import nnx
  from jax.experimental import mesh_utils
  from jax.sharding import NamedSharding, PartitionSpec
- from torchax.ops.mappings import j2t_dtype
  from vllm.config import VllmConfig
+ from vllm.distributed import get_pp_group
  from vllm.distributed.kv_transfer import (get_kv_transfer_group,
  has_kv_transfer_group)
  from vllm.forward_context import set_forward_context
- from vllm.sequence import IntermediateTensors
  from vllm.tasks import SupportedTask
  from vllm.utils.math_utils import cdiv
  from vllm.v1.core.sched.output import GrammarOutput
@@ -35,6 +33,7 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import \
  KVConnectorModelRunnerMixin
  from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

+ import tpu_inference.envs as envs
  from tpu_inference import utils as common_utils
  from tpu_inference.layers.common.attention_metadata import AttentionMetadata
  from tpu_inference.layers.common.sharding import (MESH_AXIS_NAMES,
@@ -48,6 +47,8 @@ from tpu_inference.layers.jax.sample.sampling_metadata import \
  TPUSupportedSamplingMetadata
  from tpu_inference.logger import init_logger
  from tpu_inference.models.common.model_loader import get_model
+ from tpu_inference.models.jax.jax_intermediate_tensor import \
+ JaxIntermediateTensors
  from tpu_inference.models.jax.utils.weight_utils import (
  shard_put, transfer_state_with_mappings)
  from tpu_inference.runner import utils as runner_utils
@@ -64,10 +65,12 @@ from tpu_inference.runner.structured_decoding_manager import \
  StructuredDecodingManager
  from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
  from tpu_inference.utils import (device_array, make_optimized_mesh,
- time_function)
+ time_function, to_jax_dtype, to_torch_dtype)

  logger = init_logger(__name__)

+ logging.getLogger("torchax.tensor").setLevel(logging.ERROR)
+
  INVALID_TOKEN_ID = -1
  # Smallest output size
  MIN_NUM_SEQS = 8
@@ -78,17 +81,6 @@ DUMMY_METADATA = AttentionMetadata(
  request_distribution=[0, 0, 0],
  )

- TPU_STR_DTYPE_TO_TORCH_DTYPE = {
- "half": torch.half,
- "bfloat16": torch.bfloat16,
- "float": torch.float,
- "fp8": torch.float8_e4m3fn,
- "fp8_e4m3": torch.float8_e4m3fn,
- "fp8_e5m2": torch.float8_e5m2,
- "int8": torch.int8,
- "uint8": torch.uint8,
- }
-

  class AsyncTPUModelRunnerOutput(AsyncModelRunnerOutput):
  """Holds asynchronous model output specifically from a TPU runner.
@@ -153,6 +145,7 @@ class ExecuteModelState:
  spec_decode_metadata: Optional[SpecDecodeMetadata]
  kv_connector_output: Optional[KVConnectorOutput]
  logits_indices_selector: Optional[List[int]] = None
+ padded_num_reqs: Optional[int] = None


  @functools.partial(jax.jit, donate_argnums=(0, 1, 2))
@@ -190,18 +183,28 @@ def _substitute_placeholder_token(
  return input_ids.at[token_in_tpu_cur_input_indices].set(update_values)


- def _reorder_logits_indices(logprobs_lists, logits_indices_selector):
+ def _jax_logprobs_to_lists(logprobs_tensors,
+ logits_indices_selector=None,
+ cu_num_generated_tokens=None):
+ """Convert JAX LogprobsTensors to LogprobsLists by converting JAX arrays to numpy."""
+ log_token_ids_list = logprobs_tensors.logprob_token_ids.tolist()
+ logprobs_list = logprobs_tensors.logprobs.tolist()
+ selected_token_ranks_list = logprobs_tensors.selected_token_ranks.tolist()
+
+ if logits_indices_selector is not None:
+ log_token_ids_list = [
+ log_token_ids_list[i] for i in logits_indices_selector
+ ]
+ logprobs_list = [logprobs_list[i] for i in logits_indices_selector]
+ selected_token_ranks_list = [
+ selected_token_ranks_list[i] for i in logits_indices_selector
+ ]
+
  return LogprobsLists(
- logprob_token_ids=[
- logprobs_lists.logprob_token_ids[i]
- for i in logits_indices_selector
- ],
- logprobs=[logprobs_lists.logprobs[i] for i in logits_indices_selector],
- sampled_token_ranks=[
- logprobs_lists.sampled_token_ranks[i]
- for i in logits_indices_selector
- ],
- cu_num_generated_tokens=logprobs_lists.cu_num_generated_tokens,
+ logprob_token_ids=np.asarray(log_token_ids_list),
+ logprobs=np.asarray(logprobs_list),
+ sampled_token_ranks=np.asarray(selected_token_ranks_list),
+ cu_num_generated_tokens=cu_num_generated_tokens,
  )


@@ -211,6 +214,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  self,
  vllm_config: VllmConfig,
  devices: List[Any],
+ rank: int = 0,
+ is_first_rank: bool = True,
+ is_last_rank: bool = True,
  ):
  self.vllm_config = vllm_config
  self.model_config = vllm_config.model_config
@@ -229,6 +235,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  self.maybe_forbid_compile = runner_utils.ForbidCompile(
  ) if envs.VLLM_XLA_CHECK_RECOMPILATION else nullcontext()
  self.dp_size = self.vllm_config.sharding_config.total_dp_size
+ self.rank = rank
+ self.is_first_rank = is_first_rank
+ self.is_last_rank = is_last_rank

  self._init_random()
  self._init_mesh()
@@ -239,31 +248,21 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

  # Delegate functions to specific manager classes.
  self.compilation_manager = CompilationManager(self)
- self.speculative_decoding_manager = SpeculativeDecodingManager(self)
- self.structured_decoding_manager = StructuredDecodingManager(self)
+ if self.is_last_rank:
+ self.speculative_decoding_manager = SpeculativeDecodingManager(
+ self)
+ self.structured_decoding_manager = StructuredDecodingManager(self)
  self.kv_cache_manager = KVCacheManager(self)
  self.mm_manager = MultiModalManager(self)
  self.persistent_batch_manager = PersistentBatchManager(
  self.requests, self.input_batch, self.encoder_cache,
- self.uses_mrope, self.model_config)
+ self.uses_mrope, self.model_config, self.is_last_rank)
  self.lora_utils = LoraUtils(self)

- cache_config = self.cache_config
- if cache_config.cache_dtype == "auto":
- model_dtype = self.dtype
- if isinstance(model_dtype, str):
- self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
- elif isinstance(getattr(model_dtype, 'dtype', None), jnp.dtype):
- self.kv_cache_dtype = j2t_dtype(model_dtype.dtype)
- elif isinstance(model_dtype, torch.dtype):
- self.kv_cache_dtype = model_dtype
- else:
- raise ValueError(
- "KV cache is unsupported for model_dtype of %s",
- model_dtype)
- else:
- self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[
- cache_config.cache_dtype]
+ cache_dtype = self.cache_config.cache_dtype
+ if cache_dtype == "auto":
+ cache_dtype = self.dtype
+ self.kv_cache_dtype = to_torch_dtype(cache_dtype)

  self._pre_async_results: AsyncPreResults | None = None
  self._substitute_placeholder_token_fn = _substitute_placeholder_token
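The dtype plumbing above replaces the runner's local TPU_STR_DTYPE_TO_TORCH_DTYPE table and the torchax j2t_dtype fallback with a single to_torch_dtype helper imported from tpu_inference.utils. That helper is not shown in this diff; the sketch below only illustrates the kind of consolidation it implies, and its names and coverage are assumptions rather than the package's actual implementation.

import numpy as np
import torch
import jax.numpy as jnp

# Hypothetical stand-in for tpu_inference.utils.to_torch_dtype (assumed shape).
_STR_TO_TORCH = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
    "fp8": torch.float8_e4m3fn,
    "fp8_e4m3": torch.float8_e4m3fn,
    "fp8_e5m2": torch.float8_e5m2,
    "int8": torch.int8,
    "uint8": torch.uint8,
}
_NP_TO_TORCH = {
    np.dtype(jnp.bfloat16): torch.bfloat16,
    np.dtype(np.float32): torch.float32,
    np.dtype(np.float16): torch.float16,
    np.dtype(np.int8): torch.int8,
    np.dtype(np.uint8): torch.uint8,
}

def to_torch_dtype(dtype):
    # Accept torch dtypes, the string names used by cache_config, or JAX/numpy dtypes.
    if isinstance(dtype, torch.dtype):
        return dtype
    if isinstance(dtype, str):
        return _STR_TO_TORCH[dtype]
    return _NP_TO_TORCH[np.dtype(dtype)]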
@@ -277,7 +276,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  self.rng_key = jax.random.key(self.model_config.seed)

  def _init_mesh(self) -> None:
- if os.getenv("NEW_MODEL_DESIGN", False):
+ if envs.NEW_MODEL_DESIGN:
  self.mesh = self._create_new_model_mesh()
  else:
  # NOTE(wenxindongwork): The new MoE kernel expects a 2D mesh, so we need
@@ -288,7 +287,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  logger.info(f"Init mesh | mesh={self.mesh}")

  def _create_new_model_mesh(self) -> jax.sharding.Mesh:
- num_slices = int(os.environ.get('NUM_SLICES', 1))
+ num_slices = envs.NUM_SLICES

  logger.info(f"Creating new model mesh | devices={len(self.devices)}, "
  f"num_slices={num_slices}")
@@ -357,7 +356,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  devices=self.devices)

  def _init_phased_profiling(self) -> None:
- self.phased_profiling_dir = os.getenv("PHASED_PROFILING_DIR", "")
+ self.phased_profiling_dir = envs.PHASED_PROFILING_DIR
  self.phase_based_profiler = None
  if self.phased_profiling_dir:
  self.phase_based_profiler = runner_utils.PhasedBasedProfiler(
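Several hunks in this file (NEW_MODEL_DESIGN and NUM_SLICES above, PHASED_PROFILING_DIR here, and the earlier envs.VLLM_XLA_CHECK_RECOMPILATION check) swap raw os.getenv calls for attributes of the new tpu_inference.envs module (tpu_inference/envs.py, +93 -9 in this release). The snippet below sketches one plausible shape for such a typed accessor module; the real parsing rules and defaults in envs.py are not visible in this diff and may differ (for example, it could resolve values lazily via a module-level __getattr__).

import os

def _env_bool(name: str, default: bool = False) -> bool:
    # Treat common truthy spellings as True, everything else as False.
    return os.environ.get(name, str(default)).strip().lower() in ("1", "true", "yes")

def _env_int(name: str, default: int) -> int:
    return int(os.environ.get(name, default))

def _env_str(name: str, default: str = "") -> str:
    return os.environ.get(name, default)

# Names referenced by tpu_runner.py in this diff.
NEW_MODEL_DESIGN = _env_bool("NEW_MODEL_DESIGN")
NUM_SLICES = _env_int("NUM_SLICES", 1)
PHASED_PROFILING_DIR = _env_str("PHASED_PROFILING_DIR")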
@@ -399,7 +398,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  min_token_size=max(16, self.dp_size),
  max_token_size=scheduler_config.max_num_batched_tokens *
  self.dp_size,
- padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
+ padding_gap=vllm_envs.VLLM_TPU_BUCKET_PADDING_GAP)
  self.num_tokens_paddings_per_dp = [
  padding // self.dp_size for padding in self.num_tokens_paddings
  ]
@@ -423,8 +422,14 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

  self.input_ids_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
  self.positions_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
- self.block_table_cpu = np.zeros(
- (self.max_num_reqs, self.max_num_blocks_per_req), dtype=np.int32)
+ # Note: self.input_batch and self.block_tables_cpu are both initialized
+ # with only 1 block_size. For hybrid kv cache, it will be re-init
+ # in kv_cache_manager's maybe_reinitialize_input_batch.
+ self.block_tables_cpu = [
+ np.zeros((self.max_num_reqs, self.max_num_blocks_per_req),
+ dtype=np.int32)
+ ]
+
  self.query_start_loc_cpu = np.zeros(self.max_num_reqs + self.dp_size,
  dtype=np.int32)
  self.seq_lens_cpu = np.zeros(self.max_num_reqs, dtype=np.int32)
@@ -458,9 +463,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

  # tensors for structured decoding
  self.vocab_size = self.model_config.get_vocab_size()
- if self.lora_config is not None:
- # lora_config.lora_extra_vocab_size is the "Maximum size of extra vocabulary that can be present in a LoRA adapter" per https://github.com/vanbasten23/vllm/blob/7f4a8b6705622fde952a2e633e86716f902d6e1b/vllm/config.py#L3040
- self.vocab_size += self.lora_config.lora_extra_vocab_size
  self.grammar_bitmask_cpu = np.zeros(
  (self.max_num_reqs, cdiv(self.vocab_size, 32)),
  dtype=np.int32,
@@ -505,9 +507,14 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

  self.rng_params_for_sampling = nnx.Rngs(
  jax.random.key(self.model_config.seed)).params()
- self.is_multimodal_model = (self.model_config.is_multimodal_model
- and self.get_multimodal_embeddings_fn
- is not None)
+ self.is_multimodal_model = (
+ self.model_config.is_multimodal_model
+ and self.get_multimodal_embeddings_fn is not None and hasattr(
+ self.model_config.hf_config, "architectures"
+ ) #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
+ and len(self.model_config.hf_config.architectures) >= 1
+ and self.model_config.hf_config.architectures[0]
+ != "Llama4ForConditionalGeneration")

  logger.info(f"Init model | "
  f"hbm={common_utils.hbm_usage_gb(self.devices)}GiB")
@@ -520,6 +527,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

  def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
  self.kv_cache_config = kv_cache_config
+ self.use_hybrid_kvcache = len(kv_cache_config.kv_cache_groups) > 1
  self.kv_caches = []
  self.kv_cache_manager.initialize_kv_cache(kv_cache_config)
  if has_kv_transfer_group():
@@ -532,12 +540,12 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  def execute_model(
  self,
  scheduler_output: "VllmSchedulerOutput",
- intermediate_tensors: Optional[IntermediateTensors] = None,
- ) -> ModelRunnerOutput | None:
+ intermediate_tensors: Optional[JaxIntermediateTensors] = None,
+ ) -> ModelRunnerOutput | JaxIntermediateTensors | None:
  if self.execute_model_state is not None:
  raise RuntimeError("State error: sample_tokens() must be called "
  "after execute_model() returns None.")
- _, output = self._execute_model(scheduler_output)
+ _, output = self._execute_model(scheduler_output, intermediate_tensors)
  return output

  def sample_tokens(
@@ -550,16 +558,17 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

  (scheduler_output, attn_metadata, input_ids, hidden_states, logits,
  aux_hidden_states, spec_decode_metadata, kv_connector_output,
- logits_indices_selector) = (
- self.execute_model_state.scheduler_output,
- self.execute_model_state.attn_metadata,
- self.execute_model_state.input_ids,
- self.execute_model_state.hidden_states,
- self.execute_model_state.logits,
- self.execute_model_state.aux_hidden_states,
- self.execute_model_state.spec_decode_metadata,
- self.execute_model_state.kv_connector_output,
- self.execute_model_state.logits_indices_selector)
+ logits_indices_selector,
+ padded_num_reqs) = (self.execute_model_state.scheduler_output,
+ self.execute_model_state.attn_metadata,
+ self.execute_model_state.input_ids,
+ self.execute_model_state.hidden_states,
+ self.execute_model_state.logits,
+ self.execute_model_state.aux_hidden_states,
+ self.execute_model_state.spec_decode_metadata,
+ self.execute_model_state.kv_connector_output,
+ self.execute_model_state.logits_indices_selector,
+ self.execute_model_state.padded_num_reqs)
  self.execute_model_state = None

  if grammar_output is not None:
@@ -573,12 +582,10 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  logits,
  arange,
  )
- return self._sample_from_logits(scheduler_output, attn_metadata,
- input_ids, hidden_states, logits,
- aux_hidden_states,
- spec_decode_metadata,
- kv_connector_output,
- logits_indices_selector)
+ return self._sample_from_logits(
+ scheduler_output, attn_metadata, input_ids, hidden_states, logits,
+ aux_hidden_states, spec_decode_metadata, kv_connector_output,
+ logits_indices_selector, padded_num_reqs)

  def _modify_prev_results(self):
  # If copy to host has not been done, we just wait.
@@ -664,7 +671,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  def _execute_model(
  self,
  scheduler_output: "VllmSchedulerOutput",
- ) -> tuple[AttentionMetadata, ModelRunnerOutput | None]:
+ intermediate_tensors: Optional[JaxIntermediateTensors] = None,
+ ) -> tuple[AttentionMetadata, JaxIntermediateTensors | ModelRunnerOutput
+ | None]:
  self.persistent_batch_manager.update_states(
  scheduler_output, self.get_mrope_input_positions_fn)
  if not scheduler_output.total_num_scheduled_tokens:
@@ -687,13 +696,23 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  # TODO(pooyam): I guess we can remove returning sampling_metadata in `_prepare_inputs` after https://github.com/njhill/vllm/commit/b7433ca1a47732394b1bdea4099d98389515954b
  (
  input_ids,
+ input_positions,
  attn_metadata,
  _,
  logits_indices,
  spec_decode_metadata,
  logits_indices_selector,
+ padded_num_reqs,
  ) = self._prepare_inputs(scheduler_output)

+ is_llama_guard_4 = (
+ hasattr(
+ self.model_config.hf_config, "architectures"
+ ) #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
+ and len(self.model_config.hf_config.architectures) >= 1
+ and self.model_config.hf_config.architectures[0]
+ == "Llama4ForConditionalGeneration")
+
  # multi-modal support
  if self.is_multimodal_model:
  # Run the multimodal encoder if any.
@@ -701,6 +720,13 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  self.mm_manager.execute_mm_encoder(scheduler_output)
  mm_embeds = self.mm_manager.gather_mm_embeddings(
  scheduler_output, input_ids.shape[0])
+ #TODO: Remove the follow elif statement once Llama Guard 4 Vision portion has been implemented
+ elif is_llama_guard_4 and any(
+ self.mm_manager.runner.requests[req_id].mm_features
+ for req_id in self.mm_manager.runner.input_batch.req_ids):
+ raise NotImplementedError(
+ "Llama Guard 4 (JAX) currently supports only text inputs. "
+ "Multimodal processing not yet implemented.")
  else:
  mm_embeds = []

@@ -725,7 +751,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  scheduler_output) as kv_connector_output:
  # NOTE(Wenlong): It takes both `input_ids` and `inputs_embeds`,
  # but one of them would be `None`
-
  (self.kv_caches, hidden_states,
  aux_hidden_states) = self.model_fn(
  self.state,
@@ -733,10 +758,17 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  input_ids,
  attn_metadata,
  inputs_embeds,
+ input_positions,
  tuple(self.layer_name_to_kvcache_index.items()),
  lora_metadata,
+ intermediate_tensors,
+ self.is_first_rank,
+ self.is_last_rank,
  )
-
+ if not get_pp_group().is_last_rank:
+ assert isinstance(hidden_states, JaxIntermediateTensors)
+ hidden_states.kv_connector_output = kv_connector_output
+ return attn_metadata, hidden_states
  hidden_states = self._select_from_array_fn(hidden_states,
  logits_indices)
  logits = self.compute_logits_fn(
@@ -754,7 +786,8 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  aux_hidden_states=aux_hidden_states,
  spec_decode_metadata=spec_decode_metadata,
  kv_connector_output=kv_connector_output,
- logits_indices_selector=logits_indices_selector)
+ logits_indices_selector=logits_indices_selector,
+ padded_num_reqs=padded_num_reqs)
  return attn_metadata, None

  def _sample_from_logits(
@@ -768,23 +801,44 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  spec_decode_metadata: Optional[SpecDecodeMetadata],
  kv_connector_output: Optional[KVConnectorOutput],
  logits_indices_selector: Optional[List[int]] = None,
+ padded_num_reqs: Optional[int] = None,
  ) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput:
- padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
- self.input_batch.num_reqs, self.max_num_reqs)
+ if padded_num_reqs is None:
+ padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
+ self.input_batch.num_reqs, self.max_num_reqs)
+
+ sharding = None
+ if self.dp_size > 1:
+ sharding = NamedSharding(self.mesh,
+ PartitionSpec(ShardingAxisName.ATTN_DATA))
+
  tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
- self.mesh, self.input_batch, padded_num_reqs)
+ self.mesh, self.input_batch, padded_num_reqs, sharding=sharding)
+
+ # TODO(pooyam): Should we move this to `_prepare_inputs`?
+ if tpu_sampling_metadata.do_sampling:
+ self.rng_params_for_sampling, step_rng = jax.random.split(
+ self.rng_params_for_sampling)
+ else:
+ step_rng = self.rng_params_for_sampling
+
  if spec_decode_metadata is None:
  next_tokens = sample(
- self.rng_params_for_sampling,
+ step_rng,
  self.mesh,
  logits,
  tpu_sampling_metadata,
  )
  else:
+ if tpu_sampling_metadata.do_sampling:
+ bonus_rng, rejection_rng = jax.random.split(step_rng)
+ else:
+ bonus_rng = step_rng
+ rejection_rng = step_rng
  bonus_logits = self._select_from_array_fn(
  logits, spec_decode_metadata.bonus_logits_indices)
  bonus_token_ids = sample(
- self.rng_params_for_sampling,
+ bonus_rng,
  self.mesh,
  bonus_logits,
  tpu_sampling_metadata,
@@ -798,7 +852,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  target_logits=target_logits,
  bonus_token_ids=bonus_token_ids,
  sampling_metadata=tpu_sampling_metadata,
- key=self.rng_params_for_sampling,
+ key=rejection_rng,
  )

  if tpu_sampling_metadata.logprobs:
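The two hunks above change the sampling path so that the persistent RNG key is split once per step (and again for bonus vs. rejection sampling) instead of reusing self.rng_params_for_sampling on every call, which would otherwise produce correlated draws across steps. A minimal, self-contained JAX illustration of the split-and-consume pattern:

import jax

key = jax.random.key(0)
for step in range(3):
    # Keep the carried key fresh; hand a one-off subkey to the sampler.
    key, step_key = jax.random.split(key)
    print(step, jax.random.normal(step_key, (2,)))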
@@ -856,10 +910,8 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

  if logprobs is not None:
  # Map logprobs back to the pre-dp shuffling order
- logprobs_lists = logprobs.tolists()
- if logits_indices_selector is not None:
- logprobs_lists = _reorder_logits_indices(
- logprobs_lists, logits_indices_selector)
+ logprobs_lists = _jax_logprobs_to_lists(
+ logprobs, logits_indices_selector)

  else:
  logprobs_lists = None
@@ -929,10 +981,8 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

  if logprobs is not None:
  # Map logprobs back to the pre-dp shuffling order
- logprobs_lists = logprobs.tolists()
- if logits_indices_selector is not None:
- logprobs_lists = _reorder_logits_indices(
- logprobs_lists, logits_indices_selector)
+ logprobs_lists = _jax_logprobs_to_lists(logprobs,
+ logits_indices_selector)

  else:
  logprobs_lists = None
@@ -1280,16 +1330,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  mrope_positions = self.mrope_positions_cpu[:, :
  padded_total_num_scheduled_tokens]

- block_tables = self.block_table_cpu[:self.max_num_reqs]
- for dp_rank in range(dp_size):
- req_offset = dp_rank * max_num_reqs_per_dp_rank
- _num_reqs = num_req_per_dp_rank[dp_rank]
-
- block_tables[
- req_offset:req_offset + _num_reqs, :self.
- max_num_blocks_per_req] = self.input_batch.block_table[
- 0].get_cpu_tensor()[req_indices_dp[dp_rank]]
-
  query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs +
  dp_size]
  seq_lens = self.seq_lens_cpu[:self.max_num_reqs]
@@ -1297,7 +1337,14 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  _request_distribution = []
  for dp_rank in range(dp_size):
  _num_reqs = num_req_per_dp_rank[dp_rank]
- _request_distribution.append([0, 0, _num_reqs])
+ # The batch has been reordered by _reorder_batch so decode requests come first
+ # Count decode requests (those with num_scheduled_tokens == 1) in this DP rank
+ num_decode_in_dp_rank = 0
+ for req_id in req_ids_dp[dp_rank]:
+ if scheduler_output.num_scheduled_tokens[req_id] == 1:
+ num_decode_in_dp_rank += 1
+ _request_distribution.append(
+ [num_decode_in_dp_rank, num_decode_in_dp_rank, _num_reqs])
  request_distribution = np.array(_request_distribution).ravel()

  use_spec_decode = len(
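Instead of reporting every request as a prefill ([0, 0, _num_reqs]), each DP rank's request_distribution entry now records how many of its requests are decodes (exactly one scheduled token), repeated in the first two slots, followed by the rank's total request count, as the new code above does. A worked example for one rank with three decodes and two prefills:

num_scheduled_tokens = {"r0": 1, "r1": 1, "r2": 1, "r3": 17, "r4": 9}
num_decode = sum(1 for n in num_scheduled_tokens.values() if n == 1)
entry = [num_decode, num_decode, len(num_scheduled_tokens)]
print(entry)  # [3, 3, 5]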
@@ -1331,20 +1378,59 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  if self.uses_mrope:
  positions = mrope_positions

- # Convert block_tables to 1D on cpu.
- block_tables = block_tables.reshape(-1)
-
  query_start_loc_cpu = query_start_loc
  logits_indices_cpu = logits_indices
  seq_lens_cpu = seq_lens

- (input_ids, positions, block_tables, query_start_loc, seq_lens,
- logits_indices, request_distribution) = device_array(
+ (input_ids, positions, query_start_loc, seq_lens, logits_indices,
+ request_distribution) = device_array(
  self.mesh,
- (input_ids, positions, block_tables, query_start_loc, seq_lens,
- logits_indices, request_distribution),
+ (input_ids, positions, query_start_loc, seq_lens, logits_indices,
+ request_distribution),
  sharding=data_parallel_attn_sharding,
  )
+
+ attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+ uniform_attention_metadata: AttentionMetadata = None
+ for kv_cache_gid, kv_cache_group in enumerate(
+ self.kv_cache_config.kv_cache_groups):
+ block_tables = self.block_tables_cpu[kv_cache_gid][:self.
+ max_num_reqs]
+ for dp_rank in range(dp_size):
+ req_offset = dp_rank * max_num_reqs_per_dp_rank
+ _num_reqs = num_req_per_dp_rank[dp_rank]
+
+ block_tables[
+ req_offset:req_offset + _num_reqs, :self.
+ max_num_blocks_per_req] = self.input_batch.block_table[
+ kv_cache_gid].get_cpu_tensor()[req_indices_dp[dp_rank]]
+ # Convert block_tables to 1D on cpu.
+ block_tables = block_tables.reshape(-1)
+ block_tables = device_array(
+ self.mesh,
+ (block_tables),
+ sharding=data_parallel_attn_sharding,
+ )
+
+ attention_metadata_gid = AttentionMetadata(
+ input_positions=positions,
+ block_tables=block_tables,
+ seq_lens=seq_lens,
+ query_start_loc=query_start_loc,
+ request_distribution=request_distribution,
+ )
+
+ # This is for making these cpu buffers hidden during tracing
+ attention_metadata_gid.query_start_loc_cpu = query_start_loc_cpu
+ attention_metadata_gid.seq_lens_cpu = seq_lens_cpu
+
+ if not self.use_hybrid_kvcache:
+ uniform_attention_metadata = attention_metadata_gid
+ else:
+ for layer_name in kv_cache_group.layer_names:
+ attention_metadata_per_layer[
+ layer_name] = attention_metadata_gid
+
  # Async scheduling: substitute placeholder tokens for DP
  if self.scheduler_config.async_scheduling and self._pre_async_results is not None:
  # Collect all token indices that need substitution across all DP ranks
@@ -1373,25 +1459,19 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  padded_total_num_scheduled_tokens,
  )

- attention_metadata = AttentionMetadata(
- input_positions=positions,
- block_tables=block_tables,
- seq_lens=seq_lens,
- query_start_loc=query_start_loc,
- request_distribution=request_distribution,
- )
-
- # This is for making these cpu buffers hidden during tracing
- attention_metadata.query_start_loc_cpu = query_start_loc_cpu
- attention_metadata.seq_lens_cpu = seq_lens_cpu
-
+ if self.use_hybrid_kvcache:
+ attention_metadata = attention_metadata_per_layer
+ else:
+ attention_metadata = uniform_attention_metadata
  return (
  input_ids,
+ positions,
  attention_metadata,
  sampling_metadata,
  logits_indices,
  spec_decode_metadata,
  logits_indices_selector,
+ padded_num_reqs,
  )

  def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"):
@@ -1492,9 +1572,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  positions = self.positions_cpu[:padded_total_num_scheduled_tokens]
  mrope_positions = self.mrope_positions_cpu[:, :
  padded_total_num_scheduled_tokens]
- block_tables = self.block_table_cpu[:self.max_num_reqs]
- block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
- self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs])

  # TODO(pooyam): Some paddings are up to `num_reqs_paddings` (spec decoding, select hidden states, etc) and some other are to `max_num_reqs` (block table, seq_lens). We should stick to one of them maybe?
  query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1]
@@ -1523,16 +1600,44 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  self.mesh, self.input_batch, padded_num_reqs)
  if self.uses_mrope:
  positions = mrope_positions
-
- # Convert block_tables to 1D on cpu.
- block_tables = block_tables.reshape(-1)
-
  query_start_loc_cpu = query_start_loc
  seq_lens_cpu = seq_lens
- (input_ids, positions, block_tables, query_start_loc, seq_lens,
+
+ (input_ids, positions, query_start_loc, seq_lens,
  logits_indices, request_distribution) = device_array(
- self.mesh, (input_ids, positions, block_tables, query_start_loc,
- seq_lens, logits_indices, request_distribution))
+ self.mesh, (input_ids, positions, query_start_loc, seq_lens,
+ logits_indices, request_distribution))
+
+ attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+ uniform_attention_metadata: AttentionMetadata = None
+ for kv_cache_gid, kv_cache_group in enumerate(
+ self.kv_cache_config.kv_cache_groups):
+ block_tables = self.block_tables_cpu[kv_cache_gid][:self.
+ max_num_reqs]
+ block_tables[:num_reqs] = (
+ self.input_batch.block_table[kv_cache_gid].get_cpu_tensor()
+ [:num_reqs])
+ # Convert block_tables to 1D on cpu.
+ block_tables = block_tables.reshape(-1)
+ block_tables = device_array(self.mesh, (block_tables))
+
+ attention_metadata_gid = AttentionMetadata(
+ input_positions=positions,
+ block_tables=block_tables,
+ seq_lens=seq_lens,
+ query_start_loc=query_start_loc,
+ request_distribution=request_distribution)
+ # This is for making these cpu buffers hidden during tracing
+ attention_metadata_gid.query_start_loc_cpu = query_start_loc_cpu
+ attention_metadata_gid.seq_lens_cpu = seq_lens_cpu
+
+ if not self.use_hybrid_kvcache:
+ # all layers share the same attention metadata
+ uniform_attention_metadata = attention_metadata_gid
+ else:
+ for layer_name in kv_cache_group.layer_names:
+ attention_metadata_per_layer[
+ layer_name] = attention_metadata_gid

  if self.scheduler_config.async_scheduling and len(
  token_in_tpu_cur_input_indices) > 0:
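With the hybrid KV cache, _prepare_inputs (both the DP and non-DP variants above) now returns either a single AttentionMetadata shared by all layers or a dict keyed by layer name, one entry per KV cache group. A minimal sketch, not taken from the package, of how a consumer could resolve the metadata for a given layer under either shape:

def metadata_for_layer(attention_metadata, layer_name):
    # Hybrid KV cache: per-layer dict; otherwise one shared metadata object.
    if isinstance(attention_metadata, dict):
        return attention_metadata[layer_name]
    return attention_metadata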
@@ -1545,20 +1650,15 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  self.lora_utils.set_active_loras(
  num_scheduled_tokens_per_req, total_num_scheduled_tokens,
  padded_total_num_scheduled_tokens)
-
- attention_metadata = AttentionMetadata(
- input_positions=positions,
- block_tables=block_tables,
- seq_lens=seq_lens,
- query_start_loc=query_start_loc,
- request_distribution=request_distribution)
-
- # This is for making these cpu buffers hidden during tracing
- attention_metadata.query_start_loc_cpu = query_start_loc_cpu
- attention_metadata.seq_lens_cpu = seq_lens_cpu
  logits_indices_selector = None
- return (input_ids, attention_metadata, sampling_metadata,
- logits_indices, spec_decode_metadata, logits_indices_selector)
+
+ if self.use_hybrid_kvcache:
+ attention_metadata = attention_metadata_per_layer
+ else:
+ attention_metadata = uniform_attention_metadata
+ return (input_ids, positions, attention_metadata, sampling_metadata,
+ logits_indices, spec_decode_metadata, logits_indices_selector,
+ padded_num_reqs)

  def _get_input_ids_embeds(self, input_ids: jax.Array,
  mm_embeds: list[jax.Array]):
@@ -1618,3 +1718,34 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  mappings=mappings,
  transpose_keys=transpose_keys,
  shard=shard)
+
+ def get_intermediate_tensor_spec(self, num_tokens: int):
+ jax_dtype = to_jax_dtype(self.dtype)
+ num_padded_tokens = runner_utils.get_padded_token_len(
+ self.num_tokens_paddings, num_tokens)
+ sharding = NamedSharding(self.mesh, PartitionSpec())
+ hidden_size = self.model_config.get_hidden_size()
+ spec = jax.ShapeDtypeStruct(shape=(num_padded_tokens, hidden_size),
+ dtype=jax_dtype,
+ sharding=sharding)
+ tensor_spec = {"hidden_states": spec, "residual": spec}
+ return tensor_spec
+
+ def get_uuid_for_jax_transfer(self,
+ scheduler_output: "VllmSchedulerOutput",
+ rank: int, step: int) -> int:
+ '''
+ Get a uuid for jax.transfer, here we use the hash of
+ scheduler_output + counter_step + sender's rank
+ '''
+ scheduler_output_str = ""
+ if not scheduler_output.num_scheduled_tokens:
+ scheduler_output_str = "empty_batch"
+ else:
+ scheduler_output_str = str(
+ sorted(scheduler_output.num_scheduled_tokens.items()))
+ unique_str = f'{scheduler_output_str} {step} {rank}'
+ import hashlib
+ hasher = hashlib.sha1()
+ hasher.update(unique_str.encode('utf-8'))
+ return int.from_bytes(hasher.digest()[:8], 'big')
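
get_uuid_for_jax_transfer derives a transfer id purely from the scheduled-token map, the step counter, and the sender's rank, so sender and receiver can agree on the same id without exchanging it. A standalone illustration of the same derivation:

import hashlib

def transfer_id(num_scheduled_tokens: dict, step: int, sender_rank: int) -> int:
    desc = (str(sorted(num_scheduled_tokens.items()))
            if num_scheduled_tokens else "empty_batch")
    digest = hashlib.sha1(f"{desc} {step} {sender_rank}".encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big")

print(transfer_id({"req-1": 8, "req-2": 1}, step=3, sender_rank=0))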