tpu-inference 0.11.1.dev202511150811__py3-none-any.whl → 0.11.1.dev202512030818__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (54)
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/lora/test_layers.py +0 -6
  3. tests/lora/utils.py +0 -8
  4. tests/test_envs.py +32 -11
  5. tests/test_utils.py +1 -2
  6. tpu_inference/__init__.py +22 -3
  7. tpu_inference/core/disagg_utils.py +6 -8
  8. tpu_inference/distributed/tpu_connector.py +3 -4
  9. tpu_inference/distributed/utils.py +3 -2
  10. tpu_inference/envs.py +61 -8
  11. tpu_inference/executors/ray_distributed_executor.py +31 -11
  12. tpu_inference/kernels/fused_moe/v1/kernel.py +641 -110
  13. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
  14. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +213 -126
  15. tpu_inference/layers/common/attention_interface.py +7 -1
  16. tpu_inference/layers/common/sharding.py +5 -5
  17. tpu_inference/layers/vllm/fused_moe.py +74 -25
  18. tpu_inference/layers/vllm/quantization/common.py +6 -1
  19. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -62
  20. tpu_inference/layers/vllm/quantization/unquantized.py +107 -113
  21. tpu_inference/layers/vllm/sharding.py +2 -2
  22. tpu_inference/lora/torch_punica_tpu.py +1 -2
  23. tpu_inference/models/common/model_loader.py +45 -11
  24. tpu_inference/models/jax/llama3.py +2 -1
  25. tpu_inference/models/jax/llama_eagle3.py +8 -5
  26. tpu_inference/models/jax/llama_guard_4.py +361 -0
  27. tpu_inference/models/jax/qwen2.py +2 -1
  28. tpu_inference/models/jax/qwen2_5_vl.py +163 -48
  29. tpu_inference/models/jax/qwen3.py +2 -1
  30. tpu_inference/models/jax/utils/quantization/quantization_utils.py +3 -6
  31. tpu_inference/models/jax/utils/weight_utils.py +198 -143
  32. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -7
  33. tpu_inference/platforms/tpu_platform.py +28 -22
  34. tpu_inference/runner/compilation_manager.py +144 -59
  35. tpu_inference/runner/kv_cache_manager.py +17 -18
  36. tpu_inference/runner/persistent_batch_manager.py +40 -2
  37. tpu_inference/runner/structured_decoding_manager.py +2 -3
  38. tpu_inference/runner/tpu_runner.py +271 -147
  39. tpu_inference/runner/utils.py +2 -2
  40. tpu_inference/spec_decode/jax/eagle3.py +71 -21
  41. tpu_inference/tpu_info.py +4 -3
  42. tpu_inference/utils.py +36 -13
  43. tpu_inference/worker/tpu_worker.py +162 -25
  44. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/METADATA +3 -2
  45. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/RECORD +48 -53
  46. tpu_inference/mock/__init__.py +0 -0
  47. tpu_inference/mock/vllm_config_utils.py +0 -28
  48. tpu_inference/mock/vllm_envs.py +0 -1219
  49. tpu_inference/mock/vllm_logger.py +0 -212
  50. tpu_inference/mock/vllm_logging_utils.py +0 -15
  51. tpu_inference/models/jax/phi3.py +0 -376
  52. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/WHEEL +0 -0
  53. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/licenses/LICENSE +0 -0
  54. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/top_level.txt +0 -0
@@ -10,17 +10,16 @@ import jax
  import jax.numpy as jnp
  import jaxtyping
  import numpy as np
- import torch
- import vllm.envs as envs
+ import vllm.envs as vllm_envs
  from flax import nnx
  from jax.experimental import mesh_utils
  from jax.sharding import NamedSharding, PartitionSpec
- from torchax.ops.mappings import j2t_dtype
+ from torchax.ops.mappings import t2j_dtype
  from vllm.config import VllmConfig
+ from vllm.distributed import get_pp_group
  from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                            has_kv_transfer_group)
  from vllm.forward_context import set_forward_context
- from vllm.sequence import IntermediateTensors
  from vllm.tasks import SupportedTask
  from vllm.utils.math_utils import cdiv
  from vllm.v1.core.sched.output import GrammarOutput
@@ -35,6 +34,7 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import \
      KVConnectorModelRunnerMixin
  from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 
+ import tpu_inference.envs as envs
  from tpu_inference import utils as common_utils
  from tpu_inference.layers.common.attention_metadata import AttentionMetadata
  from tpu_inference.layers.common.sharding import (MESH_AXIS_NAMES,
@@ -48,6 +48,8 @@ from tpu_inference.layers.jax.sample.sampling_metadata import \
      TPUSupportedSamplingMetadata
  from tpu_inference.logger import init_logger
  from tpu_inference.models.common.model_loader import get_model
+ from tpu_inference.models.jax.jax_intermediate_tensor import \
+     JaxIntermediateTensors
  from tpu_inference.models.jax.utils.weight_utils import (
      shard_put, transfer_state_with_mappings)
  from tpu_inference.runner import utils as runner_utils
@@ -64,7 +66,7 @@ from tpu_inference.runner.structured_decoding_manager import \
      StructuredDecodingManager
  from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
  from tpu_inference.utils import (device_array, make_optimized_mesh,
-                                  time_function)
+                                  time_function, to_torch_dtype)
 
  logger = init_logger(__name__)
 
@@ -78,17 +80,6 @@ DUMMY_METADATA = AttentionMetadata(
      request_distribution=[0, 0, 0],
  )
 
- TPU_STR_DTYPE_TO_TORCH_DTYPE = {
-     "half": torch.half,
-     "bfloat16": torch.bfloat16,
-     "float": torch.float,
-     "fp8": torch.float8_e4m3fn,
-     "fp8_e4m3": torch.float8_e4m3fn,
-     "fp8_e5m2": torch.float8_e5m2,
-     "int8": torch.int8,
-     "uint8": torch.uint8,
- }
-
 
  class AsyncTPUModelRunnerOutput(AsyncModelRunnerOutput):
      """Holds asynchronous model output specifically from a TPU runner.
@@ -153,6 +144,7 @@ class ExecuteModelState:
      spec_decode_metadata: Optional[SpecDecodeMetadata]
      kv_connector_output: Optional[KVConnectorOutput]
      logits_indices_selector: Optional[List[int]] = None
+     padded_num_reqs: Optional[int] = None
 
 
  @functools.partial(jax.jit, donate_argnums=(0, 1, 2))
@@ -190,18 +182,28 @@ def _substitute_placeholder_token(
      return input_ids.at[token_in_tpu_cur_input_indices].set(update_values)
 
 
- def _reorder_logits_indices(logprobs_lists, logits_indices_selector):
+ def _jax_logprobs_to_lists(logprobs_tensors,
+                            logits_indices_selector=None,
+                            cu_num_generated_tokens=None):
+     """Convert JAX LogprobsTensors to LogprobsLists by converting JAX arrays to numpy."""
+     log_token_ids_list = logprobs_tensors.logprob_token_ids.tolist()
+     logprobs_list = logprobs_tensors.logprobs.tolist()
+     selected_token_ranks_list = logprobs_tensors.selected_token_ranks.tolist()
+
+     if logits_indices_selector is not None:
+         log_token_ids_list = [
+             log_token_ids_list[i] for i in logits_indices_selector
+         ]
+         logprobs_list = [logprobs_list[i] for i in logits_indices_selector]
+         selected_token_ranks_list = [
+             selected_token_ranks_list[i] for i in logits_indices_selector
+         ]
+
      return LogprobsLists(
-         logprob_token_ids=[
-             logprobs_lists.logprob_token_ids[i]
-             for i in logits_indices_selector
-         ],
-         logprobs=[logprobs_lists.logprobs[i] for i in logits_indices_selector],
-         sampled_token_ranks=[
-             logprobs_lists.sampled_token_ranks[i]
-             for i in logits_indices_selector
-         ],
-         cu_num_generated_tokens=logprobs_lists.cu_num_generated_tokens,
+         logprob_token_ids=np.asarray(log_token_ids_list),
+         logprobs=np.asarray(logprobs_list),
+         sampled_token_ranks=np.asarray(selected_token_ranks_list),
+         cu_num_generated_tokens=cu_num_generated_tokens,
      )
 
 
@@ -211,6 +213,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
          self,
          vllm_config: VllmConfig,
          devices: List[Any],
+         rank: int = 0,
+         is_first_rank: bool = True,
+         is_last_rank: bool = True,
      ):
          self.vllm_config = vllm_config
          self.model_config = vllm_config.model_config
@@ -229,6 +234,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
          self.maybe_forbid_compile = runner_utils.ForbidCompile(
          ) if envs.VLLM_XLA_CHECK_RECOMPILATION else nullcontext()
          self.dp_size = self.vllm_config.sharding_config.total_dp_size
+         self.rank = rank
+         self.is_first_rank = is_first_rank
+         self.is_last_rank = is_last_rank
 
          self._init_random()
          self._init_mesh()
@@ -239,31 +247,21 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
          # Delegate functions to specific manager classes.
          self.compilation_manager = CompilationManager(self)
-         self.speculative_decoding_manager = SpeculativeDecodingManager(self)
-         self.structured_decoding_manager = StructuredDecodingManager(self)
+         if self.is_last_rank:
+             self.speculative_decoding_manager = SpeculativeDecodingManager(
+                 self)
+             self.structured_decoding_manager = StructuredDecodingManager(self)
          self.kv_cache_manager = KVCacheManager(self)
          self.mm_manager = MultiModalManager(self)
          self.persistent_batch_manager = PersistentBatchManager(
              self.requests, self.input_batch, self.encoder_cache,
-             self.uses_mrope, self.model_config)
+             self.uses_mrope, self.model_config, self.is_last_rank)
          self.lora_utils = LoraUtils(self)
 
-         cache_config = self.cache_config
-         if cache_config.cache_dtype == "auto":
-             model_dtype = self.dtype
-             if isinstance(model_dtype, str):
-                 self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
-             elif isinstance(getattr(model_dtype, 'dtype', None), jnp.dtype):
-                 self.kv_cache_dtype = j2t_dtype(model_dtype.dtype)
-             elif isinstance(model_dtype, torch.dtype):
-                 self.kv_cache_dtype = model_dtype
-             else:
-                 raise ValueError(
-                     "KV cache is unsupported for model_dtype of %s",
-                     model_dtype)
-         else:
-             self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[
-                 cache_config.cache_dtype]
+         cache_dtype = self.cache_config.cache_dtype
+         if cache_dtype == "auto":
+             cache_dtype = self.dtype
+         self.kv_cache_dtype = to_torch_dtype(cache_dtype)
 
          self._pre_async_results: AsyncPreResults | None = None
          self._substitute_placeholder_token_fn = _substitute_placeholder_token
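
Editor's note: the constructor now delegates dtype normalization to `to_torch_dtype` from `tpu_inference.utils`; the removed `TPU_STR_DTYPE_TO_TORCH_DTYPE` table and the isinstance ladder collapse into one call. The helper's implementation is not part of this diff; below is only a minimal sketch of what such a normalization function could look like, reusing the string table this diff removes (the jnp/numpy branch and its mapping are assumptions, not the packaged code).

import numpy as np
import jax.numpy as jnp
import torch

# Assumed string -> torch dtype table, mirroring the table this diff removes.
_STR_TO_TORCH = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
    "fp8": torch.float8_e4m3fn,
    "fp8_e4m3": torch.float8_e4m3fn,
    "fp8_e5m2": torch.float8_e5m2,
    "int8": torch.int8,
    "uint8": torch.uint8,
}

def to_torch_dtype_sketch(dtype):
    """Normalize a cache/model dtype spec to a torch.dtype (illustrative only)."""
    if isinstance(dtype, torch.dtype):
        return dtype
    if isinstance(dtype, str):
        return _STR_TO_TORCH[dtype]
    # Anything else (e.g. jnp.bfloat16 or a numpy dtype) is normalized through
    # numpy's dtype machinery and looked up by name.
    name = np.dtype(dtype).name
    return {"bfloat16": torch.bfloat16, "float32": torch.float32,
            "float16": torch.float16, "int8": torch.int8,
            "uint8": torch.uint8}[name]

print(to_torch_dtype_sketch("bfloat16"))    # torch.bfloat16
print(to_torch_dtype_sketch(jnp.bfloat16))  # torch.bfloat16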
@@ -277,7 +275,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
          self.rng_key = jax.random.key(self.model_config.seed)
 
      def _init_mesh(self) -> None:
-         if os.getenv("NEW_MODEL_DESIGN", False):
+         if envs.NEW_MODEL_DESIGN:
              self.mesh = self._create_new_model_mesh()
          else:
              # NOTE(wenxindongwork): The new MoE kernel expects a 2D mesh, so we need
@@ -288,7 +286,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
          logger.info(f"Init mesh | mesh={self.mesh}")
 
      def _create_new_model_mesh(self) -> jax.sharding.Mesh:
-         num_slices = int(os.environ.get('NUM_SLICES', 1))
+         num_slices = envs.NUM_SLICES
 
          logger.info(f"Creating new model mesh | devices={len(self.devices)}, "
                      f"num_slices={num_slices}")
@@ -357,7 +355,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                                        devices=self.devices)
 
      def _init_phased_profiling(self) -> None:
-         self.phased_profiling_dir = os.getenv("PHASED_PROFILING_DIR", "")
+         self.phased_profiling_dir = envs.PHASED_PROFILING_DIR
          self.phase_based_profiler = None
          if self.phased_profiling_dir:
              self.phase_based_profiler = runner_utils.PhasedBasedProfiler(
@@ -399,7 +397,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
              min_token_size=max(16, self.dp_size),
              max_token_size=scheduler_config.max_num_batched_tokens *
              self.dp_size,
-             padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
+             padding_gap=vllm_envs.VLLM_TPU_BUCKET_PADDING_GAP)
          self.num_tokens_paddings_per_dp = [
              padding // self.dp_size for padding in self.num_tokens_paddings
          ]
@@ -423,8 +421,14 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
          self.input_ids_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
          self.positions_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
-         self.block_table_cpu = np.zeros(
-             (self.max_num_reqs, self.max_num_blocks_per_req), dtype=np.int32)
+         # Note: self.input_batch and self.block_tables_cpu are both initialized
+         # with only 1 block_size. For hybrid kv cache, it will be re-init
+         # in kv_cache_manager's maybe_reinitialize_input_batch.
+         self.block_tables_cpu = [
+             np.zeros((self.max_num_reqs, self.max_num_blocks_per_req),
+                      dtype=np.int32)
+         ]
+
          self.query_start_loc_cpu = np.zeros(self.max_num_reqs + self.dp_size,
                                              dtype=np.int32)
          self.seq_lens_cpu = np.zeros(self.max_num_reqs, dtype=np.int32)
@@ -458,9 +462,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
          # tensors for structured decoding
          self.vocab_size = self.model_config.get_vocab_size()
-         if self.lora_config is not None:
-             # lora_config.lora_extra_vocab_size is the "Maximum size of extra vocabulary that can be present in a LoRA adapter" per https://github.com/vanbasten23/vllm/blob/7f4a8b6705622fde952a2e633e86716f902d6e1b/vllm/config.py#L3040
-             self.vocab_size += self.lora_config.lora_extra_vocab_size
          self.grammar_bitmask_cpu = np.zeros(
              (self.max_num_reqs, cdiv(self.vocab_size, 32)),
              dtype=np.int32,
@@ -505,9 +506,14 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
          self.rng_params_for_sampling = nnx.Rngs(
              jax.random.key(self.model_config.seed)).params()
-         self.is_multimodal_model = (self.model_config.is_multimodal_model
-                                     and self.get_multimodal_embeddings_fn
-                                     is not None)
+         self.is_multimodal_model = (
+             self.model_config.is_multimodal_model
+             and self.get_multimodal_embeddings_fn is not None and hasattr(
+                 self.model_config.hf_config, "architectures"
+             ) #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
+             and len(self.model_config.hf_config.architectures) >= 1
+             and self.model_config.hf_config.architectures[0]
+             != "Llama4ForConditionalGeneration")
 
          logger.info(f"Init model | "
                      f"hbm={common_utils.hbm_usage_gb(self.devices)}GiB")
@@ -520,6 +526,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
      def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
          self.kv_cache_config = kv_cache_config
+         self.use_hybrid_kvcache = len(kv_cache_config.kv_cache_groups) > 1
          self.kv_caches = []
          self.kv_cache_manager.initialize_kv_cache(kv_cache_config)
          if has_kv_transfer_group():
@@ -532,12 +539,12 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
      def execute_model(
          self,
          scheduler_output: "VllmSchedulerOutput",
-         intermediate_tensors: Optional[IntermediateTensors] = None,
-     ) -> ModelRunnerOutput | None:
+         intermediate_tensors: Optional[JaxIntermediateTensors] = None,
+     ) -> ModelRunnerOutput | JaxIntermediateTensors | None:
          if self.execute_model_state is not None:
              raise RuntimeError("State error: sample_tokens() must be called "
                                 "after execute_model() returns None.")
-         _, output = self._execute_model(scheduler_output)
+         _, output = self._execute_model(scheduler_output, intermediate_tensors)
          return output
 
      def sample_tokens(
@@ -550,16 +557,17 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
          (scheduler_output, attn_metadata, input_ids, hidden_states, logits,
          aux_hidden_states, spec_decode_metadata, kv_connector_output,
-         logits_indices_selector) = (
-             self.execute_model_state.scheduler_output,
-             self.execute_model_state.attn_metadata,
-             self.execute_model_state.input_ids,
-             self.execute_model_state.hidden_states,
-             self.execute_model_state.logits,
-             self.execute_model_state.aux_hidden_states,
-             self.execute_model_state.spec_decode_metadata,
-             self.execute_model_state.kv_connector_output,
-             self.execute_model_state.logits_indices_selector)
+         logits_indices_selector,
+         padded_num_reqs) = (self.execute_model_state.scheduler_output,
+                             self.execute_model_state.attn_metadata,
+                             self.execute_model_state.input_ids,
+                             self.execute_model_state.hidden_states,
+                             self.execute_model_state.logits,
+                             self.execute_model_state.aux_hidden_states,
+                             self.execute_model_state.spec_decode_metadata,
+                             self.execute_model_state.kv_connector_output,
+                             self.execute_model_state.logits_indices_selector,
+                             self.execute_model_state.padded_num_reqs)
          self.execute_model_state = None
 
          if grammar_output is not None:
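
Editor's note: `execute_model()` and `sample_tokens()` form a small two-phase protocol: the first call stashes everything sampling needs into `ExecuteModelState` (now including `padded_num_reqs`) and returns `None`, the second consumes and clears that state, and calling them out of order raises. A self-contained sketch of the same protocol, with illustrative stand-in names rather than the actual runner types:

from dataclasses import dataclass
from typing import Optional

@dataclass
class _PendingStep:
    # Stand-in for ExecuteModelState: whatever the sampling phase needs later.
    logits: list
    padded_num_reqs: int

class TwoPhaseRunner:
    def __init__(self):
        self._pending: Optional[_PendingStep] = None

    def execute_model(self, logits, padded_num_reqs):
        if self._pending is not None:
            raise RuntimeError("sample_tokens() must be called "
                               "after execute_model() returns None.")
        self._pending = _PendingStep(logits, padded_num_reqs)
        return None  # output is produced later, by sample_tokens()

    def sample_tokens(self):
        if self._pending is None:
            raise RuntimeError("execute_model() must run first.")
        state, self._pending = self._pending, None
        # "Sampling": pick the argmax per request as a stand-in.
        return [row.index(max(row)) for row in state.logits]

runner = TwoPhaseRunner()
runner.execute_model([[0.1, 0.9], [0.7, 0.3]], padded_num_reqs=2)
print(runner.sample_tokens())  # [1, 0]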
@@ -573,12 +581,10 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                  logits,
                  arange,
              )
-         return self._sample_from_logits(scheduler_output, attn_metadata,
-                                         input_ids, hidden_states, logits,
-                                         aux_hidden_states,
-                                         spec_decode_metadata,
-                                         kv_connector_output,
-                                         logits_indices_selector)
+         return self._sample_from_logits(
+             scheduler_output, attn_metadata, input_ids, hidden_states, logits,
+             aux_hidden_states, spec_decode_metadata, kv_connector_output,
+             logits_indices_selector, padded_num_reqs)
 
      def _modify_prev_results(self):
          # If copy to host has not been done, we just wait.
@@ -664,7 +670,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
      def _execute_model(
          self,
          scheduler_output: "VllmSchedulerOutput",
-     ) -> tuple[AttentionMetadata, ModelRunnerOutput | None]:
+         intermediate_tensors: Optional[JaxIntermediateTensors] = None,
+     ) -> tuple[AttentionMetadata, JaxIntermediateTensors | ModelRunnerOutput
+                | None]:
          self.persistent_batch_manager.update_states(
              scheduler_output, self.get_mrope_input_positions_fn)
          if not scheduler_output.total_num_scheduled_tokens:
@@ -687,13 +695,23 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
          # TODO(pooyam): I guess we can remove returning sampling_metadata in `_prepare_inputs` after https://github.com/njhill/vllm/commit/b7433ca1a47732394b1bdea4099d98389515954b
          (
              input_ids,
+             input_positions,
              attn_metadata,
              _,
              logits_indices,
              spec_decode_metadata,
              logits_indices_selector,
+             padded_num_reqs,
          ) = self._prepare_inputs(scheduler_output)
 
+         is_llama_guard_4 = (
+             hasattr(
+                 self.model_config.hf_config, "architectures"
+             ) #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
+             and len(self.model_config.hf_config.architectures) >= 1
+             and self.model_config.hf_config.architectures[0]
+             == "Llama4ForConditionalGeneration")
+
          # multi-modal support
          if self.is_multimodal_model:
              # Run the multimodal encoder if any.
@@ -701,6 +719,13 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                  self.mm_manager.execute_mm_encoder(scheduler_output)
              mm_embeds = self.mm_manager.gather_mm_embeddings(
                  scheduler_output, input_ids.shape[0])
+         #TODO: Remove the follow elif statement once Llama Guard 4 Vision portion has been implemented
+         elif is_llama_guard_4 and any(
+                 self.mm_manager.runner.requests[req_id].mm_features
+                 for req_id in self.mm_manager.runner.input_batch.req_ids):
+             raise NotImplementedError(
+                 "Llama Guard 4 (JAX) currently supports only text inputs. "
+                 "Multimodal processing not yet implemented.")
          else:
              mm_embeds = []
 
@@ -725,7 +750,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                  scheduler_output) as kv_connector_output:
              # NOTE(Wenlong): It takes both `input_ids` and `inputs_embeds`,
              # but one of them would be `None`
-
              (self.kv_caches, hidden_states,
               aux_hidden_states) = self.model_fn(
                   self.state,
@@ -733,10 +757,17 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                  input_ids,
                  attn_metadata,
                  inputs_embeds,
+                 input_positions,
                  tuple(self.layer_name_to_kvcache_index.items()),
                  lora_metadata,
+                 intermediate_tensors,
+                 self.is_first_rank,
+                 self.is_last_rank,
              )
-
+             if not get_pp_group().is_last_rank:
+                 assert isinstance(hidden_states, JaxIntermediateTensors)
+                 hidden_states.kv_connector_output = kv_connector_output
+                 return attn_metadata, hidden_states
          hidden_states = self._select_from_array_fn(hidden_states,
                                                     logits_indices)
          logits = self.compute_logits_fn(
@@ -754,7 +785,8 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
              aux_hidden_states=aux_hidden_states,
              spec_decode_metadata=spec_decode_metadata,
              kv_connector_output=kv_connector_output,
-             logits_indices_selector=logits_indices_selector)
+             logits_indices_selector=logits_indices_selector,
+             padded_num_reqs=padded_num_reqs)
          return attn_metadata, None
 
      def _sample_from_logits(
@@ -768,23 +800,44 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
          spec_decode_metadata: Optional[SpecDecodeMetadata],
          kv_connector_output: Optional[KVConnectorOutput],
          logits_indices_selector: Optional[List[int]] = None,
+         padded_num_reqs: Optional[int] = None,
      ) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput:
-         padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
-             self.input_batch.num_reqs, self.max_num_reqs)
+         if padded_num_reqs is None:
+             padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
+                 self.input_batch.num_reqs, self.max_num_reqs)
+
+         sharding = None
+         if self.dp_size > 1:
+             sharding = NamedSharding(self.mesh,
+                                      PartitionSpec(ShardingAxisName.ATTN_DATA))
+
          tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
-             self.mesh, self.input_batch, padded_num_reqs)
+             self.mesh, self.input_batch, padded_num_reqs, sharding=sharding)
+
+         # TODO(pooyam): Should we move this to `_prepare_inputs`?
+         if tpu_sampling_metadata.do_sampling:
+             self.rng_params_for_sampling, step_rng = jax.random.split(
+                 self.rng_params_for_sampling)
+         else:
+             step_rng = self.rng_params_for_sampling
+
          if spec_decode_metadata is None:
              next_tokens = sample(
-                 self.rng_params_for_sampling,
+                 step_rng,
                  self.mesh,
                  logits,
                  tpu_sampling_metadata,
              )
          else:
+             if tpu_sampling_metadata.do_sampling:
+                 bonus_rng, rejection_rng = jax.random.split(step_rng)
+             else:
+                 bonus_rng = step_rng
+                 rejection_rng = step_rng
              bonus_logits = self._select_from_array_fn(
                  logits, spec_decode_metadata.bonus_logits_indices)
              bonus_token_ids = sample(
-                 self.rng_params_for_sampling,
+                 bonus_rng,
                  self.mesh,
                  bonus_logits,
                  tpu_sampling_metadata,
@@ -798,7 +851,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                  target_logits=target_logits,
                  bonus_token_ids=bonus_token_ids,
                  sampling_metadata=tpu_sampling_metadata,
-                 key=self.rng_params_for_sampling,
+                 key=rejection_rng,
              )
 
          if tpu_sampling_metadata.logprobs:
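
Editor's note: the sampling changes above follow the usual JAX PRNG discipline: instead of passing the same `self.rng_params_for_sampling` key to every step (which would reproduce identical draws), the key is split into a fresh per-step key plus a new carry key, and split once more when bonus and rejection sampling each need independent randomness. A standalone illustration of why the split matters:

import jax

key = jax.random.key(0)

# Reusing one key gives the same draw every time.
reused_a = jax.random.uniform(key, (3,))
reused_b = jax.random.uniform(key, (3,))
print(bool((reused_a == reused_b).all()))  # True

# Splitting carries a new key forward and yields fresh randomness per step.
key, step_rng = jax.random.split(key)
step_a = jax.random.uniform(step_rng, (3,))
key, step_rng = jax.random.split(key)
step_b = jax.random.uniform(step_rng, (3,))
print(bool((step_a == step_b).all()))  # False

# When two independent consumers need randomness within one step (bonus vs.
# rejection sampling), split the per-step key once more.
bonus_rng, rejection_rng = jax.random.split(step_rng)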
@@ -856,10 +909,8 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
          if logprobs is not None:
              # Map logprobs back to the pre-dp shuffling order
-             logprobs_lists = logprobs.tolists()
-             if logits_indices_selector is not None:
-                 logprobs_lists = _reorder_logits_indices(
-                     logprobs_lists, logits_indices_selector)
+             logprobs_lists = _jax_logprobs_to_lists(
+                 logprobs, logits_indices_selector)
 
          else:
              logprobs_lists = None
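
Editor's note: `_jax_logprobs_to_lists`, used here and in the synchronous path below, converts the JAX logprob tensors to host lists and, when a `logits_indices_selector` is present, gathers rows back into the pre-data-parallel-shuffle order before wrapping them as numpy arrays. The gather step can be checked in isolation with plain numpy stand-ins for the vLLM LogprobsTensors/LogprobsLists types:

import numpy as np

def reorder_then_convert(logprob_token_ids, logprobs, ranks, selector=None):
    # Mirror of the _jax_logprobs_to_lists flow: tolist() first, then an
    # optional gather by `selector`, then back to numpy arrays.
    token_ids = logprob_token_ids.tolist()
    probs = logprobs.tolist()
    ranks = ranks.tolist()
    if selector is not None:
        token_ids = [token_ids[i] for i in selector]
        probs = [probs[i] for i in selector]
        ranks = [ranks[i] for i in selector]
    return np.asarray(token_ids), np.asarray(probs), np.asarray(ranks)

# Three requests whose rows were shuffled for data-parallel execution;
# selector[i] says which shuffled row belongs at original position i.
token_ids = np.array([[11, 12], [21, 22], [31, 32]])
logprobs = np.array([[-0.1, -2.0], [-0.2, -1.5], [-0.3, -1.0]])
ranks = np.array([0, 1, 2])
selector = [2, 0, 1]

out_ids, out_probs, out_ranks = reorder_then_convert(token_ids, logprobs, ranks, selector)
print(out_ids)    # rows now ordered [31 32], [11 12], [21 22]
print(out_ranks)  # [2 0 1]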
@@ -929,10 +980,8 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
          if logprobs is not None:
              # Map logprobs back to the pre-dp shuffling order
-             logprobs_lists = logprobs.tolists()
-             if logits_indices_selector is not None:
-                 logprobs_lists = _reorder_logits_indices(
-                     logprobs_lists, logits_indices_selector)
+             logprobs_lists = _jax_logprobs_to_lists(logprobs,
+                                                     logits_indices_selector)
          else:
              logprobs_lists = None
 
@@ -1280,16 +1329,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
          mrope_positions = self.mrope_positions_cpu[:, :
                                                     padded_total_num_scheduled_tokens]
 
-         block_tables = self.block_table_cpu[:self.max_num_reqs]
-         for dp_rank in range(dp_size):
-             req_offset = dp_rank * max_num_reqs_per_dp_rank
-             _num_reqs = num_req_per_dp_rank[dp_rank]
-
-             block_tables[
-                 req_offset:req_offset + _num_reqs, :self.
-                 max_num_blocks_per_req] = self.input_batch.block_table[
-                     0].get_cpu_tensor()[req_indices_dp[dp_rank]]
-
          query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs +
                                                     dp_size]
          seq_lens = self.seq_lens_cpu[:self.max_num_reqs]
@@ -1331,20 +1370,59 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
          if self.uses_mrope:
              positions = mrope_positions
 
-         # Convert block_tables to 1D on cpu.
-         block_tables = block_tables.reshape(-1)
-
          query_start_loc_cpu = query_start_loc
          logits_indices_cpu = logits_indices
          seq_lens_cpu = seq_lens
 
-         (input_ids, positions, block_tables, query_start_loc, seq_lens,
-          logits_indices, request_distribution) = device_array(
+         (input_ids, positions, query_start_loc, seq_lens, logits_indices,
+          request_distribution) = device_array(
              self.mesh,
-             (input_ids, positions, block_tables, query_start_loc, seq_lens,
-              logits_indices, request_distribution),
+             (input_ids, positions, query_start_loc, seq_lens, logits_indices,
+              request_distribution),
              sharding=data_parallel_attn_sharding,
          )
+
+         attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+         uniform_attention_metadata: AttentionMetadata = None
+         for kv_cache_gid, kv_cache_group in enumerate(
+                 self.kv_cache_config.kv_cache_groups):
+             block_tables = self.block_tables_cpu[kv_cache_gid][:self.
+                                                                max_num_reqs]
+             for dp_rank in range(dp_size):
+                 req_offset = dp_rank * max_num_reqs_per_dp_rank
+                 _num_reqs = num_req_per_dp_rank[dp_rank]
+
+                 block_tables[
+                     req_offset:req_offset + _num_reqs, :self.
+                     max_num_blocks_per_req] = self.input_batch.block_table[
+                         0].get_cpu_tensor()[req_indices_dp[dp_rank]]
+             # Convert block_tables to 1D on cpu.
+             block_tables = block_tables.reshape(-1)
+             block_tables = device_array(
+                 self.mesh,
+                 (block_tables),
+                 sharding=data_parallel_attn_sharding,
+             )
+
+             attention_metadata_gid = AttentionMetadata(
+                 input_positions=positions,
+                 block_tables=block_tables,
+                 seq_lens=seq_lens,
+                 query_start_loc=query_start_loc,
+                 request_distribution=request_distribution,
+             )
+
+             # This is for making these cpu buffers hidden during tracing
+             attention_metadata_gid.query_start_loc_cpu = query_start_loc_cpu
+             attention_metadata_gid.seq_lens_cpu = seq_lens_cpu
+
+             if not self.use_hybrid_kvcache:
+                 uniform_attention_metadata = attention_metadata_gid
+             else:
+                 for layer_name in kv_cache_group.layer_names:
+                     attention_metadata_per_layer[
+                         layer_name] = attention_metadata_gid
+
          # Async scheduling: substitute placeholder tokens for DP
          if self.scheduler_config.async_scheduling and self._pre_async_results is not None:
              # Collect all token indices that need substitution across all DP ranks
@@ -1373,25 +1451,19 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                  padded_total_num_scheduled_tokens,
              )
 
-         attention_metadata = AttentionMetadata(
-             input_positions=positions,
-             block_tables=block_tables,
-             seq_lens=seq_lens,
-             query_start_loc=query_start_loc,
-             request_distribution=request_distribution,
-         )
-
-         # This is for making these cpu buffers hidden during tracing
-         attention_metadata.query_start_loc_cpu = query_start_loc_cpu
-         attention_metadata.seq_lens_cpu = seq_lens_cpu
-
+         if self.use_hybrid_kvcache:
+             attention_metadata = attention_metadata_per_layer
+         else:
+             attention_metadata = uniform_attention_metadata
          return (
              input_ids,
+             positions,
              attention_metadata,
              sampling_metadata,
              logits_indices,
              spec_decode_metadata,
              logits_indices_selector,
+             padded_num_reqs,
          )
 
      def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"):
@@ -1492,9 +1564,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
          positions = self.positions_cpu[:padded_total_num_scheduled_tokens]
          mrope_positions = self.mrope_positions_cpu[:, :
                                                     padded_total_num_scheduled_tokens]
-         block_tables = self.block_table_cpu[:self.max_num_reqs]
-         block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
-             self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs])
 
          # TODO(pooyam): Some paddings are up to `num_reqs_paddings` (spec decoding, select hidden states, etc) and some other are to `max_num_reqs` (block table, seq_lens). We should stick to one of them maybe?
          query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1]
@@ -1523,16 +1592,44 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
              self.mesh, self.input_batch, padded_num_reqs)
          if self.uses_mrope:
              positions = mrope_positions
-
-         # Convert block_tables to 1D on cpu.
-         block_tables = block_tables.reshape(-1)
-
          query_start_loc_cpu = query_start_loc
          seq_lens_cpu = seq_lens
-         (input_ids, positions, block_tables, query_start_loc, seq_lens,
+
+         (input_ids, positions, query_start_loc, seq_lens,
           logits_indices, request_distribution) = device_array(
-             self.mesh, (input_ids, positions, block_tables, query_start_loc,
-                         seq_lens, logits_indices, request_distribution))
+             self.mesh, (input_ids, positions, query_start_loc, seq_lens,
+                         logits_indices, request_distribution))
+
+         attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+         uniform_attention_metadata: AttentionMetadata = None
+         for kv_cache_gid, kv_cache_group in enumerate(
+                 self.kv_cache_config.kv_cache_groups):
+             block_tables = self.block_tables_cpu[kv_cache_gid][:self.
+                                                                max_num_reqs]
+             block_tables[:num_reqs] = (
+                 self.input_batch.block_table[kv_cache_gid].get_cpu_tensor()
+                 [:num_reqs])
+             # Convert block_tables to 1D on cpu.
+             block_tables = block_tables.reshape(-1)
+             block_tables = device_array(self.mesh, (block_tables))
+
+             attention_metadata_gid = AttentionMetadata(
+                 input_positions=positions,
+                 block_tables=block_tables,
+                 seq_lens=seq_lens,
+                 query_start_loc=query_start_loc,
+                 request_distribution=request_distribution)
+             # This is for making these cpu buffers hidden during tracing
+             attention_metadata_gid.query_start_loc_cpu = query_start_loc_cpu
+             attention_metadata_gid.seq_lens_cpu = seq_lens_cpu
+
+             if not self.use_hybrid_kvcache:
+                 # all layers share the same attention metadata
+                 uniform_attention_metadata = attention_metadata_gid
+             else:
+                 for layer_name in kv_cache_group.layer_names:
+                     attention_metadata_per_layer[
+                         layer_name] = attention_metadata_gid
 
          if self.scheduler_config.async_scheduling and len(
                  token_in_tpu_cur_input_indices) > 0:
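
Editor's note: both `_prepare_inputs` variants now build one `AttentionMetadata` per KV-cache group; with a single group every layer shares the same object, while in the hybrid KV-cache case the metadata is fanned out into a per-layer dict keyed by `kv_cache_group.layer_names`. A small standalone sketch of that grouping, with plain dataclasses standing in for the tpu_inference types:

from dataclasses import dataclass
from typing import Dict, List

@dataclass
class GroupMetadata:
    # Stand-in for AttentionMetadata: one block table per KV-cache group.
    block_tables: List[int]

@dataclass
class KVCacheGroup:
    layer_names: List[str]
    block_tables: List[int]

def build_metadata(groups: List[KVCacheGroup]):
    use_hybrid = len(groups) > 1
    per_layer: Dict[str, GroupMetadata] = {}
    uniform = None
    for group in groups:
        md = GroupMetadata(block_tables=group.block_tables)
        if not use_hybrid:
            uniform = md  # every layer shares the same metadata
        else:
            for name in group.layer_names:
                per_layer[name] = md  # layers in a group share one object
    return per_layer if use_hybrid else uniform

groups = [
    KVCacheGroup(["layers.0.attn", "layers.2.attn"], block_tables=[0, 1]),
    KVCacheGroup(["layers.1.sliding_attn"], block_tables=[2, 3]),
]
meta = build_metadata(groups)
print(meta["layers.1.sliding_attn"].block_tables)  # [2, 3]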
@@ -1545,20 +1642,15 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
              self.lora_utils.set_active_loras(
                  num_scheduled_tokens_per_req, total_num_scheduled_tokens,
                  padded_total_num_scheduled_tokens)
-
-         attention_metadata = AttentionMetadata(
-             input_positions=positions,
-             block_tables=block_tables,
-             seq_lens=seq_lens,
-             query_start_loc=query_start_loc,
-             request_distribution=request_distribution)
-
-         # This is for making these cpu buffers hidden during tracing
-         attention_metadata.query_start_loc_cpu = query_start_loc_cpu
-         attention_metadata.seq_lens_cpu = seq_lens_cpu
          logits_indices_selector = None
-         return (input_ids, attention_metadata, sampling_metadata,
-                 logits_indices, spec_decode_metadata, logits_indices_selector)
+
+         if self.use_hybrid_kvcache:
+             attention_metadata = attention_metadata_per_layer
+         else:
+             attention_metadata = uniform_attention_metadata
+         return (input_ids, positions, attention_metadata, sampling_metadata,
+                 logits_indices, spec_decode_metadata, logits_indices_selector,
+                 padded_num_reqs)
 
      def _get_input_ids_embeds(self, input_ids: jax.Array,
                                mm_embeds: list[jax.Array]):
@@ -1618,3 +1710,35 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
              mappings=mappings,
              transpose_keys=transpose_keys,
              shard=shard)
+
+     def get_intermediate_tensor_spec(self, num_tokens: int):
+         impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
+         jax_dtype = t2j_dtype(self.dtype) if impl == "vllm" else self.dtype
+         num_padded_tokens = runner_utils.get_padded_token_len(
+             self.num_tokens_paddings, num_tokens)
+         sharding = NamedSharding(self.mesh, PartitionSpec())
+         hidden_size = self.model_config.get_hidden_size()
+         spec = jax.ShapeDtypeStruct(shape=(num_padded_tokens, hidden_size),
+                                     dtype=jax_dtype,
+                                     sharding=sharding)
+         tensor_spec = {"hidden_states": spec, "residual": spec}
+         return tensor_spec
+
+     def get_uuid_for_jax_transfer(self,
+                                   scheduler_output: "VllmSchedulerOutput",
+                                   rank: int, step: int) -> int:
+         '''
+         Get a uuid for jax.transfer, here we use the hash of
+         scheduler_output + counter_step + sender's rank
+         '''
+         scheduler_output_str = ""
+         if not scheduler_output.num_scheduled_tokens:
+             scheduler_output_str = "empty_batch"
+         else:
+             scheduler_output_str = str(
+                 sorted(scheduler_output.num_scheduled_tokens.items()))
+         unique_str = f'{scheduler_output_str} {step} {rank}'
+         import hashlib
+         hasher = hashlib.sha1()
+         hasher.update(unique_str.encode('utf-8'))
+         return int.from_bytes(hasher.digest()[:8], 'big')
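
Editor's note: `get_uuid_for_jax_transfer` derives the transfer id from data both pipeline ends already know (the per-request scheduled-token counts, the step counter, and the sender's rank), so sender and receiver compute the same 64-bit value without exchanging it; sorting the dict items keeps the hash independent of dict ordering. The scheme in isolation:

import hashlib

def transfer_uuid(num_scheduled_tokens: dict, step: int, rank: int) -> int:
    """Derive a deterministic 64-bit id from data both pipeline ends already know."""
    if not num_scheduled_tokens:
        payload = "empty_batch"
    else:
        # Sort so dict ordering never changes the hash.
        payload = str(sorted(num_scheduled_tokens.items()))
    unique_str = f"{payload} {step} {rank}"
    digest = hashlib.sha1(unique_str.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big")

# Sender (rank 0) and a receiver compute the same id for the same step.
print(transfer_uuid({"req-a": 17, "req-b": 3}, step=42, rank=0))
print(transfer_uuid({"req-b": 3, "req-a": 17}, step=42, rank=0))  # identical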