tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511180814__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic.

Files changed (56)
  1. tests/kernels/fused_moe_v1_test.py +34 -303
  2. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
  3. tests/lora/test_layers.py +6 -0
  4. tests/lora/utils.py +8 -0
  5. tests/test_envs.py +11 -32
  6. tests/test_utils.py +2 -1
  7. tpu_inference/__init__.py +3 -22
  8. tpu_inference/core/disagg_utils.py +8 -6
  9. tpu_inference/distributed/tpu_connector.py +4 -3
  10. tpu_inference/distributed/utils.py +2 -3
  11. tpu_inference/envs.py +8 -61
  12. tpu_inference/executors/ray_distributed_executor.py +2 -9
  13. tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
  14. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
  15. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +145 -266
  16. tpu_inference/layers/common/attention_interface.py +1 -7
  17. tpu_inference/layers/common/sharding.py +5 -5
  18. tpu_inference/layers/vllm/fused_moe.py +208 -170
  19. tpu_inference/layers/vllm/quantization/common.py +1 -6
  20. tpu_inference/layers/vllm/quantization/mxfp4.py +73 -138
  21. tpu_inference/layers/vllm/quantization/unquantized.py +64 -58
  22. tpu_inference/layers/vllm/sharding.py +2 -2
  23. tpu_inference/lora/torch_punica_tpu.py +2 -1
  24. tpu_inference/mock/__init__.py +0 -0
  25. tpu_inference/mock/vllm_config_utils.py +28 -0
  26. tpu_inference/mock/vllm_envs.py +1219 -0
  27. tpu_inference/mock/vllm_logger.py +212 -0
  28. tpu_inference/mock/vllm_logging_utils.py +15 -0
  29. tpu_inference/models/common/model_loader.py +10 -43
  30. tpu_inference/models/jax/llama3.py +1 -2
  31. tpu_inference/models/jax/llama_eagle3.py +5 -8
  32. tpu_inference/models/jax/phi3.py +376 -0
  33. tpu_inference/models/jax/qwen2.py +1 -2
  34. tpu_inference/models/jax/qwen2_5_vl.py +48 -163
  35. tpu_inference/models/jax/qwen3.py +1 -2
  36. tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
  37. tpu_inference/models/jax/utils/weight_utils.py +143 -198
  38. tpu_inference/models/vllm/vllm_model_wrapper.py +8 -14
  39. tpu_inference/platforms/tpu_platform.py +31 -37
  40. tpu_inference/runner/compilation_manager.py +58 -141
  41. tpu_inference/runner/kv_cache.py +1 -1
  42. tpu_inference/runner/kv_cache_manager.py +18 -17
  43. tpu_inference/runner/persistent_batch_manager.py +2 -40
  44. tpu_inference/runner/structured_decoding_manager.py +3 -2
  45. tpu_inference/runner/tpu_runner.py +147 -271
  46. tpu_inference/runner/utils.py +2 -2
  47. tpu_inference/spec_decode/jax/eagle3.py +21 -71
  48. tpu_inference/tpu_info.py +3 -4
  49. tpu_inference/utils.py +13 -36
  50. tpu_inference/worker/tpu_worker.py +25 -162
  51. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/METADATA +3 -4
  52. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/RECORD +55 -50
  53. tpu_inference/models/jax/llama_guard_4.py +0 -361
  54. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/WHEEL +0 -0
  55. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/licenses/LICENSE +0 -0
  56. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/top_level.txt +0 -0
@@ -10,16 +10,17 @@ import jax
 import jax.numpy as jnp
 import jaxtyping
 import numpy as np
-import vllm.envs as vllm_envs
+import torch
+import vllm.envs as envs
 from flax import nnx
 from jax.experimental import mesh_utils
 from jax.sharding import NamedSharding, PartitionSpec
-from torchax.ops.mappings import t2j_dtype
+from torchax.ops.mappings import j2t_dtype
 from vllm.config import VllmConfig
-from vllm.distributed import get_pp_group
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.forward_context import set_forward_context
+from vllm.sequence import IntermediateTensors
 from vllm.tasks import SupportedTask
 from vllm.utils.math_utils import cdiv
 from vllm.v1.core.sched.output import GrammarOutput
@@ -34,7 +35,6 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import \
     KVConnectorModelRunnerMixin
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 
-import tpu_inference.envs as envs
 from tpu_inference import utils as common_utils
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.common.sharding import (MESH_AXIS_NAMES,
@@ -48,8 +48,6 @@ from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata
 from tpu_inference.logger import init_logger
 from tpu_inference.models.common.model_loader import get_model
-from tpu_inference.models.jax.jax_intermediate_tensor import \
-    JaxIntermediateTensors
 from tpu_inference.models.jax.utils.weight_utils import (
     shard_put, transfer_state_with_mappings)
 from tpu_inference.runner import utils as runner_utils
@@ -66,7 +64,7 @@ from tpu_inference.runner.structured_decoding_manager import \
     StructuredDecodingManager
 from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
 from tpu_inference.utils import (device_array, make_optimized_mesh,
-                                 time_function, to_torch_dtype)
+                                 time_function)
 
 logger = init_logger(__name__)
 
@@ -80,6 +78,17 @@ DUMMY_METADATA = AttentionMetadata(
     request_distribution=[0, 0, 0],
 )
 
+TPU_STR_DTYPE_TO_TORCH_DTYPE = {
+    "half": torch.half,
+    "bfloat16": torch.bfloat16,
+    "float": torch.float,
+    "fp8": torch.float8_e4m3fn,
+    "fp8_e4m3": torch.float8_e4m3fn,
+    "fp8_e5m2": torch.float8_e5m2,
+    "int8": torch.int8,
+    "uint8": torch.uint8,
+}
+
 
 class AsyncTPUModelRunnerOutput(AsyncModelRunnerOutput):
     """Holds asynchronous model output specifically from a TPU runner.
@@ -144,7 +153,6 @@ class ExecuteModelState:
     spec_decode_metadata: Optional[SpecDecodeMetadata]
     kv_connector_output: Optional[KVConnectorOutput]
     logits_indices_selector: Optional[List[int]] = None
-    padded_num_reqs: Optional[int] = None
 
 
 @functools.partial(jax.jit, donate_argnums=(0, 1, 2))
@@ -182,28 +190,18 @@ def _substitute_placeholder_token(
     return input_ids.at[token_in_tpu_cur_input_indices].set(update_values)
 
 
-def _jax_logprobs_to_lists(logprobs_tensors,
-                           logits_indices_selector=None,
-                           cu_num_generated_tokens=None):
-    """Convert JAX LogprobsTensors to LogprobsLists by converting JAX arrays to numpy."""
-    log_token_ids_list = logprobs_tensors.logprob_token_ids.tolist()
-    logprobs_list = logprobs_tensors.logprobs.tolist()
-    selected_token_ranks_list = logprobs_tensors.selected_token_ranks.tolist()
-
-    if logits_indices_selector is not None:
-        log_token_ids_list = [
-            log_token_ids_list[i] for i in logits_indices_selector
-        ]
-        logprobs_list = [logprobs_list[i] for i in logits_indices_selector]
-        selected_token_ranks_list = [
-            selected_token_ranks_list[i] for i in logits_indices_selector
-        ]
-
+def _reorder_logits_indices(logprobs_lists, logits_indices_selector):
     return LogprobsLists(
-        logprob_token_ids=np.asarray(log_token_ids_list),
-        logprobs=np.asarray(logprobs_list),
-        sampled_token_ranks=np.asarray(selected_token_ranks_list),
-        cu_num_generated_tokens=cu_num_generated_tokens,
+        logprob_token_ids=[
+            logprobs_lists.logprob_token_ids[i]
+            for i in logits_indices_selector
+        ],
+        logprobs=[logprobs_lists.logprobs[i] for i in logits_indices_selector],
+        sampled_token_ranks=[
+            logprobs_lists.sampled_token_ranks[i]
+            for i in logits_indices_selector
+        ],
+        cu_num_generated_tokens=logprobs_lists.cu_num_generated_tokens,
     )
 
 
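Note: the replacement helper above no longer converts JAX arrays itself; it only re-orders an already materialized LogprobsLists by a selector (conversion now happens via `logprobs.tolists()` at the call sites later in this diff). Below is a minimal, self-contained sketch of that re-ordering pattern; it uses a hypothetical named-tuple stand-in for vLLM's LogprobsLists rather than the real class.

```python
from typing import NamedTuple, Optional

# Hypothetical stand-in for vLLM's LogprobsLists, for illustration only.
class FakeLogprobsLists(NamedTuple):
    logprob_token_ids: list
    logprobs: list
    sampled_token_ranks: list
    cu_num_generated_tokens: Optional[list] = None

def reorder_by_selector(lists, selector):
    # Same list-comprehension pattern as _reorder_logits_indices above:
    # output row i comes from input row selector[i].
    return FakeLogprobsLists(
        logprob_token_ids=[lists.logprob_token_ids[i] for i in selector],
        logprobs=[lists.logprobs[i] for i in selector],
        sampled_token_ranks=[lists.sampled_token_ranks[i] for i in selector],
        cu_num_generated_tokens=lists.cu_num_generated_tokens,
    )

rows = FakeLogprobsLists([[5], [7], [9]], [[-0.1], [-0.2], [-0.3]], [1, 2, 3])
print(reorder_by_selector(rows, [2, 0, 1]).sampled_token_ranks)  # [3, 1, 2]
```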
@@ -213,9 +211,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         self,
         vllm_config: VllmConfig,
         devices: List[Any],
-        rank: int = 0,
-        is_first_rank: bool = True,
-        is_last_rank: bool = True,
     ):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
@@ -234,9 +229,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         self.maybe_forbid_compile = runner_utils.ForbidCompile(
         ) if envs.VLLM_XLA_CHECK_RECOMPILATION else nullcontext()
         self.dp_size = self.vllm_config.sharding_config.total_dp_size
-        self.rank = rank
-        self.is_first_rank = is_first_rank
-        self.is_last_rank = is_last_rank
 
         self._init_random()
         self._init_mesh()
@@ -247,21 +239,31 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
         # Delegate functions to specific manager classes.
         self.compilation_manager = CompilationManager(self)
-        if self.is_last_rank:
-            self.speculative_decoding_manager = SpeculativeDecodingManager(
-                self)
-            self.structured_decoding_manager = StructuredDecodingManager(self)
+        self.speculative_decoding_manager = SpeculativeDecodingManager(self)
+        self.structured_decoding_manager = StructuredDecodingManager(self)
         self.kv_cache_manager = KVCacheManager(self)
         self.mm_manager = MultiModalManager(self)
         self.persistent_batch_manager = PersistentBatchManager(
             self.requests, self.input_batch, self.encoder_cache,
-            self.uses_mrope, self.model_config, self.is_last_rank)
+            self.uses_mrope, self.model_config)
         self.lora_utils = LoraUtils(self)
 
-        cache_dtype = self.cache_config.cache_dtype
-        if cache_dtype == "auto":
-            cache_dtype = self.dtype
-        self.kv_cache_dtype = to_torch_dtype(cache_dtype)
+        cache_config = self.cache_config
+        if cache_config.cache_dtype == "auto":
+            model_dtype = self.dtype
+            if isinstance(model_dtype, str):
+                self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
+            elif isinstance(getattr(model_dtype, 'dtype', None), jnp.dtype):
+                self.kv_cache_dtype = j2t_dtype(model_dtype.dtype)
+            elif isinstance(model_dtype, torch.dtype):
+                self.kv_cache_dtype = model_dtype
+            else:
+                raise ValueError(
+                    "KV cache is unsupported for model_dtype of %s",
+                    model_dtype)
+        else:
+            self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[
+                cache_config.cache_dtype]
 
         self._pre_async_results: AsyncPreResults | None = None
         self._substitute_placeholder_token_fn = _substitute_placeholder_token
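Note: the constructor now resolves the KV-cache dtype from a string name, a JAX dtype carrier, or a torch dtype, instead of the removed `to_torch_dtype` helper. The following is a rough, standalone sketch of that resolution order, assuming the `TPU_STR_DTYPE_TO_TORCH_DTYPE` table added earlier in this diff and substituting a tiny inline mapping for torchax's `j2t_dtype` so the snippet runs without torchax.

```python
import jax.numpy as jnp
import torch

# Mirrors the table added near the top of this diff.
TPU_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
    "fp8": torch.float8_e4m3fn,
    "fp8_e4m3": torch.float8_e4m3fn,
    "fp8_e5m2": torch.float8_e5m2,
    "int8": torch.int8,
    "uint8": torch.uint8,
}

# Tiny stand-in for torchax.ops.mappings.j2t_dtype (illustration only).
_J2T = {
    jnp.dtype(jnp.bfloat16): torch.bfloat16,
    jnp.dtype(jnp.float16): torch.float16,
    jnp.dtype(jnp.float32): torch.float32,
}

def resolve_kv_cache_dtype(cache_dtype, model_dtype):
    """Same resolution order as the new __init__ logic (sketch, not the real code)."""
    if cache_dtype != "auto":
        return TPU_STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
    if isinstance(model_dtype, str):
        return TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
    if isinstance(getattr(model_dtype, "dtype", None), jnp.dtype):
        return _J2T[model_dtype.dtype]
    if isinstance(model_dtype, torch.dtype):
        return model_dtype
    raise ValueError(f"KV cache is unsupported for model_dtype of {model_dtype}")

print(resolve_kv_cache_dtype("fp8_e5m2", torch.bfloat16))           # torch.float8_e5m2
print(resolve_kv_cache_dtype("auto", "bfloat16"))                   # torch.bfloat16
print(resolve_kv_cache_dtype("auto", jnp.zeros((), jnp.bfloat16)))  # torch.bfloat16
```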
@@ -275,7 +277,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         self.rng_key = jax.random.key(self.model_config.seed)
 
     def _init_mesh(self) -> None:
-        if envs.NEW_MODEL_DESIGN:
+        if os.getenv("NEW_MODEL_DESIGN", False):
             self.mesh = self._create_new_model_mesh()
         else:
             # NOTE(wenxindongwork): The new MoE kernel expects a 2D mesh, so we need
@@ -286,7 +288,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         logger.info(f"Init mesh | mesh={self.mesh}")
 
     def _create_new_model_mesh(self) -> jax.sharding.Mesh:
-        num_slices = envs.NUM_SLICES
+        num_slices = int(os.environ.get('NUM_SLICES', 1))
 
         logger.info(f"Creating new model mesh | devices={len(self.devices)}, "
                     f"num_slices={num_slices}")
@@ -355,7 +357,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             devices=self.devices)
 
     def _init_phased_profiling(self) -> None:
-        self.phased_profiling_dir = envs.PHASED_PROFILING_DIR
+        self.phased_profiling_dir = os.getenv("PHASED_PROFILING_DIR", "")
         self.phase_based_profiler = None
         if self.phased_profiling_dir:
             self.phase_based_profiler = runner_utils.PhasedBasedProfiler(
@@ -397,7 +399,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             min_token_size=max(16, self.dp_size),
             max_token_size=scheduler_config.max_num_batched_tokens *
             self.dp_size,
-            padding_gap=vllm_envs.VLLM_TPU_BUCKET_PADDING_GAP)
+            padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
         self.num_tokens_paddings_per_dp = [
             padding // self.dp_size for padding in self.num_tokens_paddings
         ]
@@ -421,14 +423,8 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
         self.input_ids_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
         self.positions_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
-        # Note: self.input_batch and self.block_tables_cpu are both initialized
-        # with only 1 block_size. For hybrid kv cache, it will be re-init
-        # in kv_cache_manager's maybe_reinitialize_input_batch.
-        self.block_tables_cpu = [
-            np.zeros((self.max_num_reqs, self.max_num_blocks_per_req),
-                     dtype=np.int32)
-        ]
-
+        self.block_table_cpu = np.zeros(
+            (self.max_num_reqs, self.max_num_blocks_per_req), dtype=np.int32)
         self.query_start_loc_cpu = np.zeros(self.max_num_reqs + self.dp_size,
                                             dtype=np.int32)
         self.seq_lens_cpu = np.zeros(self.max_num_reqs, dtype=np.int32)
@@ -462,6 +458,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
         # tensors for structured decoding
         self.vocab_size = self.model_config.get_vocab_size()
+        if self.lora_config is not None:
+            # lora_config.lora_extra_vocab_size is the "Maximum size of extra vocabulary that can be present in a LoRA adapter" per https://github.com/vanbasten23/vllm/blob/7f4a8b6705622fde952a2e633e86716f902d6e1b/vllm/config.py#L3040
+            self.vocab_size += self.lora_config.lora_extra_vocab_size
         self.grammar_bitmask_cpu = np.zeros(
             (self.max_num_reqs, cdiv(self.vocab_size, 32)),
             dtype=np.int32,
@@ -506,14 +505,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
         self.rng_params_for_sampling = nnx.Rngs(
             jax.random.key(self.model_config.seed)).params()
-        self.is_multimodal_model = (
-            self.model_config.is_multimodal_model
-            and self.get_multimodal_embeddings_fn is not None and hasattr(
-                self.model_config.hf_config, "architectures"
-            )  #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
-            and len(self.model_config.hf_config.architectures) >= 1
-            and self.model_config.hf_config.architectures[0]
-            != "Llama4ForConditionalGeneration")
+        self.is_multimodal_model = (self.model_config.is_multimodal_model
+                                    and self.get_multimodal_embeddings_fn
+                                    is not None)
 
         logger.info(f"Init model | "
                     f"hbm={common_utils.hbm_usage_gb(self.devices)}GiB")
@@ -526,7 +520,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         self.kv_cache_config = kv_cache_config
-        self.use_hybrid_kvcache = len(kv_cache_config.kv_cache_groups) > 1
         self.kv_caches = []
         self.kv_cache_manager.initialize_kv_cache(kv_cache_config)
         if has_kv_transfer_group():
@@ -539,12 +532,12 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
     def execute_model(
         self,
         scheduler_output: "VllmSchedulerOutput",
-        intermediate_tensors: Optional[JaxIntermediateTensors] = None,
-    ) -> ModelRunnerOutput | JaxIntermediateTensors | None:
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+    ) -> ModelRunnerOutput | None:
         if self.execute_model_state is not None:
             raise RuntimeError("State error: sample_tokens() must be called "
                                "after execute_model() returns None.")
-        _, output = self._execute_model(scheduler_output, intermediate_tensors)
+        _, output = self._execute_model(scheduler_output)
         return output
 
     def sample_tokens(
@@ -557,17 +550,16 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
         (scheduler_output, attn_metadata, input_ids, hidden_states, logits,
          aux_hidden_states, spec_decode_metadata, kv_connector_output,
-         logits_indices_selector,
-         padded_num_reqs) = (self.execute_model_state.scheduler_output,
-                             self.execute_model_state.attn_metadata,
-                             self.execute_model_state.input_ids,
-                             self.execute_model_state.hidden_states,
-                             self.execute_model_state.logits,
-                             self.execute_model_state.aux_hidden_states,
-                             self.execute_model_state.spec_decode_metadata,
-                             self.execute_model_state.kv_connector_output,
-                             self.execute_model_state.logits_indices_selector,
-                             self.execute_model_state.padded_num_reqs)
+         logits_indices_selector) = (
+             self.execute_model_state.scheduler_output,
+             self.execute_model_state.attn_metadata,
+             self.execute_model_state.input_ids,
+             self.execute_model_state.hidden_states,
+             self.execute_model_state.logits,
+             self.execute_model_state.aux_hidden_states,
+             self.execute_model_state.spec_decode_metadata,
+             self.execute_model_state.kv_connector_output,
+             self.execute_model_state.logits_indices_selector)
         self.execute_model_state = None
 
         if grammar_output is not None:
@@ -581,10 +573,12 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 logits,
                 arange,
             )
-        return self._sample_from_logits(
-            scheduler_output, attn_metadata, input_ids, hidden_states, logits,
-            aux_hidden_states, spec_decode_metadata, kv_connector_output,
-            logits_indices_selector, padded_num_reqs)
+        return self._sample_from_logits(scheduler_output, attn_metadata,
+                                        input_ids, hidden_states, logits,
+                                        aux_hidden_states,
+                                        spec_decode_metadata,
+                                        kv_connector_output,
+                                        logits_indices_selector)
 
     def _modify_prev_results(self):
         # If copy to host has not been done, we just wait.
@@ -670,9 +664,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
     def _execute_model(
         self,
         scheduler_output: "VllmSchedulerOutput",
-        intermediate_tensors: Optional[JaxIntermediateTensors] = None,
-    ) -> tuple[AttentionMetadata, JaxIntermediateTensors | ModelRunnerOutput
-               | None]:
+    ) -> tuple[AttentionMetadata, ModelRunnerOutput | None]:
         self.persistent_batch_manager.update_states(
             scheduler_output, self.get_mrope_input_positions_fn)
         if not scheduler_output.total_num_scheduled_tokens:
@@ -695,23 +687,13 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         # TODO(pooyam): I guess we can remove returning sampling_metadata in `_prepare_inputs` after https://github.com/njhill/vllm/commit/b7433ca1a47732394b1bdea4099d98389515954b
         (
             input_ids,
-            input_positions,
             attn_metadata,
             _,
            logits_indices,
            spec_decode_metadata,
            logits_indices_selector,
-            padded_num_reqs,
         ) = self._prepare_inputs(scheduler_output)
 
-        is_llama_guard_4 = (
-            hasattr(
-                self.model_config.hf_config, "architectures"
-            )  #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
-            and len(self.model_config.hf_config.architectures) >= 1
-            and self.model_config.hf_config.architectures[0]
-            == "Llama4ForConditionalGeneration")
-
         # multi-modal support
         if self.is_multimodal_model:
             # Run the multimodal encoder if any.
@@ -719,13 +701,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 self.mm_manager.execute_mm_encoder(scheduler_output)
             mm_embeds = self.mm_manager.gather_mm_embeddings(
                 scheduler_output, input_ids.shape[0])
-        #TODO: Remove the follow elif statement once Llama Guard 4 Vision portion has been implemented
-        elif is_llama_guard_4 and any(
-                self.mm_manager.runner.requests[req_id].mm_features
-                for req_id in self.mm_manager.runner.input_batch.req_ids):
-            raise NotImplementedError(
-                "Llama Guard 4 (JAX) currently supports only text inputs. "
-                "Multimodal processing not yet implemented.")
         else:
             mm_embeds = []
 
@@ -750,6 +725,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 scheduler_output) as kv_connector_output:
             # NOTE(Wenlong): It takes both `input_ids` and `inputs_embeds`,
             # but one of them would be `None`
+
             (self.kv_caches, hidden_states,
              aux_hidden_states) = self.model_fn(
                  self.state,
@@ -757,17 +733,10 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                  input_ids,
                  attn_metadata,
                  inputs_embeds,
-                 input_positions,
                  tuple(self.layer_name_to_kvcache_index.items()),
                  lora_metadata,
-                 intermediate_tensors,
-                 self.is_first_rank,
-                 self.is_last_rank,
              )
-            if not get_pp_group().is_last_rank:
-                assert isinstance(hidden_states, JaxIntermediateTensors)
-                hidden_states.kv_connector_output = kv_connector_output
-                return attn_metadata, hidden_states
+
             hidden_states = self._select_from_array_fn(hidden_states,
                                                        logits_indices)
             logits = self.compute_logits_fn(
@@ -785,8 +754,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             aux_hidden_states=aux_hidden_states,
             spec_decode_metadata=spec_decode_metadata,
             kv_connector_output=kv_connector_output,
-            logits_indices_selector=logits_indices_selector,
-            padded_num_reqs=padded_num_reqs)
+            logits_indices_selector=logits_indices_selector)
         return attn_metadata, None
 
     def _sample_from_logits(
@@ -800,44 +768,23 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         spec_decode_metadata: Optional[SpecDecodeMetadata],
         kv_connector_output: Optional[KVConnectorOutput],
         logits_indices_selector: Optional[List[int]] = None,
-        padded_num_reqs: Optional[int] = None,
     ) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput:
-        if padded_num_reqs is None:
-            padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
-                self.input_batch.num_reqs, self.max_num_reqs)
-
-        sharding = None
-        if self.dp_size > 1:
-            sharding = NamedSharding(self.mesh,
-                                     PartitionSpec(ShardingAxisName.ATTN_DATA))
-
+        padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
+            self.input_batch.num_reqs, self.max_num_reqs)
         tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
-            self.mesh, self.input_batch, padded_num_reqs, sharding=sharding)
-
-        # TODO(pooyam): Should we move this to `_prepare_inputs`?
-        if tpu_sampling_metadata.do_sampling:
-            self.rng_params_for_sampling, step_rng = jax.random.split(
-                self.rng_params_for_sampling)
-        else:
-            step_rng = self.rng_params_for_sampling
-
+            self.mesh, self.input_batch, padded_num_reqs)
         if spec_decode_metadata is None:
             next_tokens = sample(
-                step_rng,
+                self.rng_params_for_sampling,
                 self.mesh,
                 logits,
                 tpu_sampling_metadata,
             )
         else:
-            if tpu_sampling_metadata.do_sampling:
-                bonus_rng, rejection_rng = jax.random.split(step_rng)
-            else:
-                bonus_rng = step_rng
-                rejection_rng = step_rng
             bonus_logits = self._select_from_array_fn(
                 logits, spec_decode_metadata.bonus_logits_indices)
             bonus_token_ids = sample(
-                bonus_rng,
+                self.rng_params_for_sampling,
                 self.mesh,
                 bonus_logits,
                 tpu_sampling_metadata,
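Note: this hunk drops the per-step `jax.random.split` bookkeeping and passes `self.rng_params_for_sampling` directly to `sample` and the rejection sampler. For context only, a minimal JAX sketch of why the split pattern existed: reusing one key reproduces the same draw, while splitting yields a fresh subkey per step.

```python
import jax

key = jax.random.key(0)

# Reusing the same key gives the same draw every time.
a = jax.random.uniform(key)
b = jax.random.uniform(key)
assert a == b

# Splitting advances the carry key and yields a fresh subkey per step.
key, step_key = jax.random.split(key)
c = jax.random.uniform(step_key)
key, step_key = jax.random.split(key)
d = jax.random.uniform(step_key)
assert c != d
```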
@@ -851,7 +798,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 target_logits=target_logits,
                 bonus_token_ids=bonus_token_ids,
                 sampling_metadata=tpu_sampling_metadata,
-                key=rejection_rng,
+                key=self.rng_params_for_sampling,
             )
 
         if tpu_sampling_metadata.logprobs:
@@ -909,8 +856,10 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
             if logprobs is not None:
                 # Map logprobs back to the pre-dp shuffling order
-                logprobs_lists = _jax_logprobs_to_lists(
-                    logprobs, logits_indices_selector)
+                logprobs_lists = logprobs.tolists()
+                if logits_indices_selector is not None:
+                    logprobs_lists = _reorder_logits_indices(
+                        logprobs_lists, logits_indices_selector)
 
             else:
                 logprobs_lists = None
@@ -980,8 +929,10 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
         if logprobs is not None:
             # Map logprobs back to the pre-dp shuffling order
-            logprobs_lists = _jax_logprobs_to_lists(logprobs,
-                                                    logits_indices_selector)
+            logprobs_lists = logprobs.tolists()
+            if logits_indices_selector is not None:
+                logprobs_lists = _reorder_logits_indices(
+                    logprobs_lists, logits_indices_selector)
 
         else:
             logprobs_lists = None
@@ -1329,6 +1280,16 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         mrope_positions = self.mrope_positions_cpu[:, :
                                                    padded_total_num_scheduled_tokens]
 
+        block_tables = self.block_table_cpu[:self.max_num_reqs]
+        for dp_rank in range(dp_size):
+            req_offset = dp_rank * max_num_reqs_per_dp_rank
+            _num_reqs = num_req_per_dp_rank[dp_rank]
+
+            block_tables[
+                req_offset:req_offset + _num_reqs, :self.
+                max_num_blocks_per_req] = self.input_batch.block_table[
+                    0].get_cpu_tensor()[req_indices_dp[dp_rank]]
+
         query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs +
                                                    dp_size]
         seq_lens = self.seq_lens_cpu[:self.max_num_reqs]
@@ -1370,59 +1331,20 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         if self.uses_mrope:
             positions = mrope_positions
 
+        # Convert block_tables to 1D on cpu.
+        block_tables = block_tables.reshape(-1)
+
         query_start_loc_cpu = query_start_loc
         logits_indices_cpu = logits_indices
         seq_lens_cpu = seq_lens
 
-        (input_ids, positions, query_start_loc, seq_lens, logits_indices,
-         request_distribution) = device_array(
+        (input_ids, positions, block_tables, query_start_loc, seq_lens,
+         logits_indices, request_distribution) = device_array(
             self.mesh,
-            (input_ids, positions, query_start_loc, seq_lens, logits_indices,
-             request_distribution),
+            (input_ids, positions, block_tables, query_start_loc, seq_lens,
+             logits_indices, request_distribution),
             sharding=data_parallel_attn_sharding,
         )
-
-        attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
-        uniform_attention_metadata: AttentionMetadata = None
-        for kv_cache_gid, kv_cache_group in enumerate(
-                self.kv_cache_config.kv_cache_groups):
-            block_tables = self.block_tables_cpu[kv_cache_gid][:self.
-                                                               max_num_reqs]
-            for dp_rank in range(dp_size):
-                req_offset = dp_rank * max_num_reqs_per_dp_rank
-                _num_reqs = num_req_per_dp_rank[dp_rank]
-
-                block_tables[
-                    req_offset:req_offset + _num_reqs, :self.
-                    max_num_blocks_per_req] = self.input_batch.block_table[
-                        0].get_cpu_tensor()[req_indices_dp[dp_rank]]
-            # Convert block_tables to 1D on cpu.
-            block_tables = block_tables.reshape(-1)
-            block_tables = device_array(
-                self.mesh,
-                (block_tables),
-                sharding=data_parallel_attn_sharding,
-            )
-
-            attention_metadata_gid = AttentionMetadata(
-                input_positions=positions,
-                block_tables=block_tables,
-                seq_lens=seq_lens,
-                query_start_loc=query_start_loc,
-                request_distribution=request_distribution,
-            )
-
-            # This is for making these cpu buffers hidden during tracing
-            attention_metadata_gid.query_start_loc_cpu = query_start_loc_cpu
-            attention_metadata_gid.seq_lens_cpu = seq_lens_cpu
-
-            if not self.use_hybrid_kvcache:
-                uniform_attention_metadata = attention_metadata_gid
-            else:
-                for layer_name in kv_cache_group.layer_names:
-                    attention_metadata_per_layer[
-                        layer_name] = attention_metadata_gid
-
         # Async scheduling: substitute placeholder tokens for DP
         if self.scheduler_config.async_scheduling and self._pre_async_results is not None:
             # Collect all token indices that need substitution across all DP ranks
@@ -1451,19 +1373,25 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 padded_total_num_scheduled_tokens,
             )
 
-        if self.use_hybrid_kvcache:
-            attention_metadata = attention_metadata_per_layer
-        else:
-            attention_metadata = uniform_attention_metadata
+        attention_metadata = AttentionMetadata(
+            input_positions=positions,
+            block_tables=block_tables,
+            seq_lens=seq_lens,
+            query_start_loc=query_start_loc,
+            request_distribution=request_distribution,
+        )
+
+        # This is for making these cpu buffers hidden during tracing
+        attention_metadata.query_start_loc_cpu = query_start_loc_cpu
+        attention_metadata.seq_lens_cpu = seq_lens_cpu
+
         return (
             input_ids,
-            positions,
             attention_metadata,
             sampling_metadata,
             logits_indices,
            spec_decode_metadata,
            logits_indices_selector,
-            padded_num_reqs,
         )
 
     def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"):
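Note: with the hybrid-KV-cache branching removed, `_prepare_inputs` now builds a single AttentionMetadata from one flattened block table. Below is a small numpy-only sketch of the pad-and-flatten step; the sizes are made up for illustration and are not taken from the diff.

```python
import numpy as np

# Illustrative sizes only; not taken from the diff.
max_num_reqs, max_num_blocks_per_req, num_reqs = 4, 8, 2

# Persistent padded CPU buffer, like the new self.block_table_cpu.
block_table_cpu = np.zeros((max_num_reqs, max_num_blocks_per_req), dtype=np.int32)

# Pretend per-request block ids coming from the input batch's block table.
live_blocks = np.arange(num_reqs * max_num_blocks_per_req,
                        dtype=np.int32).reshape(num_reqs, -1)

block_tables = block_table_cpu[:max_num_reqs]
block_tables[:num_reqs, :max_num_blocks_per_req] = live_blocks

# The device/kernel-facing form is a flat 1D array, as in block_tables.reshape(-1).
flat = block_tables.reshape(-1)
print(flat.shape)  # (32,)
```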
@@ -1564,6 +1492,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         positions = self.positions_cpu[:padded_total_num_scheduled_tokens]
         mrope_positions = self.mrope_positions_cpu[:, :
                                                    padded_total_num_scheduled_tokens]
+        block_tables = self.block_table_cpu[:self.max_num_reqs]
+        block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
+            self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs])
 
         # TODO(pooyam): Some paddings are up to `num_reqs_paddings` (spec decoding, select hidden states, etc) and some other are to `max_num_reqs` (block table, seq_lens). We should stick to one of them maybe?
         query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1]
@@ -1592,44 +1523,16 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             self.mesh, self.input_batch, padded_num_reqs)
         if self.uses_mrope:
             positions = mrope_positions
+
+        # Convert block_tables to 1D on cpu.
+        block_tables = block_tables.reshape(-1)
+
         query_start_loc_cpu = query_start_loc
         seq_lens_cpu = seq_lens
-
-        (input_ids, positions, query_start_loc, seq_lens,
+        (input_ids, positions, block_tables, query_start_loc, seq_lens,
         logits_indices, request_distribution) = device_array(
-            self.mesh, (input_ids, positions, query_start_loc, seq_lens,
-                        logits_indices, request_distribution))
-
-        attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
-        uniform_attention_metadata: AttentionMetadata = None
-        for kv_cache_gid, kv_cache_group in enumerate(
-                self.kv_cache_config.kv_cache_groups):
-            block_tables = self.block_tables_cpu[kv_cache_gid][:self.
-                                                               max_num_reqs]
-            block_tables[:num_reqs] = (
-                self.input_batch.block_table[kv_cache_gid].get_cpu_tensor()
-                [:num_reqs])
-            # Convert block_tables to 1D on cpu.
-            block_tables = block_tables.reshape(-1)
-            block_tables = device_array(self.mesh, (block_tables))
-
-            attention_metadata_gid = AttentionMetadata(
-                input_positions=positions,
-                block_tables=block_tables,
-                seq_lens=seq_lens,
-                query_start_loc=query_start_loc,
-                request_distribution=request_distribution)
-            # This is for making these cpu buffers hidden during tracing
-            attention_metadata_gid.query_start_loc_cpu = query_start_loc_cpu
-            attention_metadata_gid.seq_lens_cpu = seq_lens_cpu
-
-            if not self.use_hybrid_kvcache:
-                # all layers share the same attention metadata
-                uniform_attention_metadata = attention_metadata_gid
-            else:
-                for layer_name in kv_cache_group.layer_names:
-                    attention_metadata_per_layer[
-                        layer_name] = attention_metadata_gid
+            self.mesh, (input_ids, positions, block_tables, query_start_loc,
+                        seq_lens, logits_indices, request_distribution))
 
 
         if self.scheduler_config.async_scheduling and len(
@@ -1642,15 +1545,20 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         self.lora_utils.set_active_loras(
             num_scheduled_tokens_per_req, total_num_scheduled_tokens,
             padded_total_num_scheduled_tokens)
-        logits_indices_selector = None
 
-        if self.use_hybrid_kvcache:
-            attention_metadata = attention_metadata_per_layer
-        else:
-            attention_metadata = uniform_attention_metadata
-        return (input_ids, positions, attention_metadata, sampling_metadata,
-                logits_indices, spec_decode_metadata, logits_indices_selector,
-                padded_num_reqs)
+        attention_metadata = AttentionMetadata(
+            input_positions=positions,
+            block_tables=block_tables,
+            seq_lens=seq_lens,
+            query_start_loc=query_start_loc,
+            request_distribution=request_distribution)
+
+        # This is for making these cpu buffers hidden during tracing
+        attention_metadata.query_start_loc_cpu = query_start_loc_cpu
+        attention_metadata.seq_lens_cpu = seq_lens_cpu
+        logits_indices_selector = None
+        return (input_ids, attention_metadata, sampling_metadata,
+                logits_indices, spec_decode_metadata, logits_indices_selector)
 
 
     def _get_input_ids_embeds(self, input_ids: jax.Array,
@@ -1710,35 +1618,3 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             mappings=mappings,
             transpose_keys=transpose_keys,
             shard=shard)
-
-    def get_intermediate_tensor_spec(self, num_tokens: int):
-        impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
-        jax_dtype = t2j_dtype(self.dtype) if impl == "vllm" else self.dtype
-        num_padded_tokens = runner_utils.get_padded_token_len(
-            self.num_tokens_paddings, num_tokens)
-        sharding = NamedSharding(self.mesh, PartitionSpec())
-        hidden_size = self.model_config.get_hidden_size()
-        spec = jax.ShapeDtypeStruct(shape=(num_padded_tokens, hidden_size),
-                                    dtype=jax_dtype,
-                                    sharding=sharding)
-        tensor_spec = {"hidden_states": spec, "residual": spec}
-        return tensor_spec
-
-    def get_uuid_for_jax_transfer(self,
-                                  scheduler_output: "VllmSchedulerOutput",
-                                  rank: int, step: int) -> int:
-        '''
-        Get a uuid for jax.transfer, here we use the hash of
-        scheduler_output + counter_step + sender's rank
-        '''
-        scheduler_output_str = ""
-        if not scheduler_output.num_scheduled_tokens:
-            scheduler_output_str = "empty_batch"
-        else:
-            scheduler_output_str = str(
-                sorted(scheduler_output.num_scheduled_tokens.items()))
-        unique_str = f'{scheduler_output_str} {step} {rank}'
-        import hashlib
-        hasher = hashlib.sha1()
-        hasher.update(unique_str.encode('utf-8'))
-        return int.from_bytes(hasher.digest()[:8], 'big')