tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511130813__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic.

Files changed (67)
  1. tests/kernels/fused_moe_v1_test.py +34 -303
  2. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
  3. tests/lora/test_layers.py +6 -0
  4. tests/lora/utils.py +8 -0
  5. tests/test_utils.py +16 -24
  6. tpu_inference/__init__.py +3 -22
  7. tpu_inference/core/core_tpu.py +9 -17
  8. tpu_inference/core/disagg_utils.py +8 -6
  9. tpu_inference/distributed/tpu_connector.py +4 -3
  10. tpu_inference/distributed/utils.py +2 -3
  11. tpu_inference/envs.py +8 -61
  12. tpu_inference/executors/ray_distributed_executor.py +11 -31
  13. tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
  14. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
  15. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +143 -287
  16. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +0 -7
  17. tpu_inference/layers/jax/attention/attention.py +1 -1
  18. tpu_inference/layers/{common → jax}/attention_interface.py +2 -8
  19. tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
  20. tpu_inference/layers/jax/sample/sampling.py +2 -2
  21. tpu_inference/layers/{common → jax}/sharding.py +5 -5
  22. tpu_inference/layers/vllm/attention.py +1 -1
  23. tpu_inference/layers/vllm/fused_moe.py +208 -170
  24. tpu_inference/layers/vllm/quantization/__init__.py +3 -7
  25. tpu_inference/layers/vllm/quantization/awq.py +3 -4
  26. tpu_inference/layers/vllm/quantization/common.py +1 -6
  27. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +2 -4
  28. tpu_inference/layers/vllm/quantization/unquantized.py +67 -62
  29. tpu_inference/layers/vllm/sharding.py +2 -2
  30. tpu_inference/lora/torch_punica_tpu.py +2 -1
  31. tpu_inference/mock/__init__.py +0 -0
  32. tpu_inference/mock/vllm_config_utils.py +28 -0
  33. tpu_inference/mock/vllm_envs.py +1219 -0
  34. tpu_inference/mock/vllm_logger.py +212 -0
  35. tpu_inference/mock/vllm_logging_utils.py +15 -0
  36. tpu_inference/models/common/model_loader.py +12 -46
  37. tpu_inference/models/jax/llama3.py +3 -4
  38. tpu_inference/models/jax/llama_eagle3.py +5 -8
  39. tpu_inference/models/jax/phi3.py +376 -0
  40. tpu_inference/models/jax/qwen2.py +2 -3
  41. tpu_inference/models/jax/qwen2_5_vl.py +50 -165
  42. tpu_inference/models/jax/qwen3.py +2 -3
  43. tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
  44. tpu_inference/models/jax/utils/weight_utils.py +143 -198
  45. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -32
  46. tpu_inference/platforms/tpu_platform.py +34 -47
  47. tpu_inference/runner/compilation_manager.py +60 -145
  48. tpu_inference/runner/kv_cache.py +2 -2
  49. tpu_inference/runner/kv_cache_manager.py +18 -17
  50. tpu_inference/runner/persistent_batch_manager.py +2 -40
  51. tpu_inference/runner/structured_decoding_manager.py +3 -2
  52. tpu_inference/runner/tpu_runner.py +135 -283
  53. tpu_inference/runner/utils.py +2 -2
  54. tpu_inference/spec_decode/jax/eagle3.py +21 -71
  55. tpu_inference/tpu_info.py +3 -4
  56. tpu_inference/utils.py +15 -38
  57. tpu_inference/worker/tpu_worker.py +26 -163
  58. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/METADATA +3 -4
  59. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/RECORD +63 -61
  60. tests/test_envs.py +0 -203
  61. tpu_inference/layers/common/quant_methods.py +0 -8
  62. tpu_inference/layers/vllm/quantization/mxfp4.py +0 -331
  63. tpu_inference/models/jax/llama_guard_4.py +0 -361
  64. /tpu_inference/layers/{common → jax}/binary_search.py +0 -0
  65. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/WHEEL +0 -0
  66. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/licenses/LICENSE +0 -0
  67. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/top_level.txt +0 -0
@@ -10,23 +10,24 @@ import jax
  import jax.numpy as jnp
  import jaxtyping
  import numpy as np
- import vllm.envs as vllm_envs
+ import torch
+ import vllm.envs as envs
  from flax import nnx
  from jax.experimental import mesh_utils
  from jax.sharding import NamedSharding, PartitionSpec
- from torchax.ops.mappings import t2j_dtype
+ from torchax.ops.mappings import j2t_dtype
  from vllm.config import VllmConfig
- from vllm.distributed import get_pp_group
  from vllm.distributed.kv_transfer import (get_kv_transfer_group,
  has_kv_transfer_group)
  from vllm.forward_context import set_forward_context
+ from vllm.sequence import IntermediateTensors
  from vllm.tasks import SupportedTask
  from vllm.utils.math_utils import cdiv
  from vllm.v1.core.sched.output import GrammarOutput
  from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
  from vllm.v1.kv_cache_interface import KVCacheConfig
  from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
- DraftTokenIds, KVConnectorOutput, LogprobsLists,
+ DraftTokenIds, KVConnectorOutput,
  ModelRunnerOutput)
  from vllm.v1.request import Request
  from vllm.v1.spec_decode.ngram_proposer import NgramProposer
@@ -34,22 +35,19 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import \
  KVConnectorModelRunnerMixin
  from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

- import tpu_inference.envs as envs
  from tpu_inference import utils as common_utils
  from tpu_inference.layers.common.attention_metadata import AttentionMetadata
- from tpu_inference.layers.common.sharding import (MESH_AXIS_NAMES,
- MESH_AXIS_NAMES_2D,
- ShardingAxisName,
- ShardingConfigManager)
  from tpu_inference.layers.jax.sample.rejection_sampler import RejectionSampler
  from tpu_inference.layers.jax.sample.sampling import (compute_logprobs,
  gather_logprobs, sample)
  from tpu_inference.layers.jax.sample.sampling_metadata import \
  TPUSupportedSamplingMetadata
+ from tpu_inference.layers.jax.sharding import (MESH_AXIS_NAMES,
+ MESH_AXIS_NAMES_2D,
+ ShardingAxisName,
+ ShardingConfigManager)
  from tpu_inference.logger import init_logger
  from tpu_inference.models.common.model_loader import get_model
- from tpu_inference.models.jax.jax_intermediate_tensor import \
- JaxIntermediateTensors
  from tpu_inference.models.jax.utils.weight_utils import (
  shard_put, transfer_state_with_mappings)
  from tpu_inference.runner import utils as runner_utils
@@ -66,7 +64,7 @@ from tpu_inference.runner.structured_decoding_manager import \
  StructuredDecodingManager
  from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
  from tpu_inference.utils import (device_array, make_optimized_mesh,
- time_function, to_torch_dtype)
+ time_function)

  logger = init_logger(__name__)

@@ -80,6 +78,17 @@ DUMMY_METADATA = AttentionMetadata(
  request_distribution=[0, 0, 0],
  )

+ TPU_STR_DTYPE_TO_TORCH_DTYPE = {
+ "half": torch.half,
+ "bfloat16": torch.bfloat16,
+ "float": torch.float,
+ "fp8": torch.float8_e4m3fn,
+ "fp8_e4m3": torch.float8_e4m3fn,
+ "fp8_e5m2": torch.float8_e5m2,
+ "int8": torch.int8,
+ "uint8": torch.uint8,
+ }
+

  class AsyncTPUModelRunnerOutput(AsyncModelRunnerOutput):
  """Holds asynchronous model output specifically from a TPU runner.
@@ -144,7 +153,6 @@ class ExecuteModelState:
  spec_decode_metadata: Optional[SpecDecodeMetadata]
  kv_connector_output: Optional[KVConnectorOutput]
  logits_indices_selector: Optional[List[int]] = None
- padded_num_reqs: Optional[int] = None


  @functools.partial(jax.jit, donate_argnums=(0, 1, 2))
@@ -182,40 +190,12 @@ def _substitute_placeholder_token(
  return input_ids.at[token_in_tpu_cur_input_indices].set(update_values)


- def _jax_logprobs_to_lists(logprobs_tensors,
- logits_indices_selector=None,
- cu_num_generated_tokens=None):
- """Convert JAX LogprobsTensors to LogprobsLists by converting JAX arrays to numpy."""
- log_token_ids_list = logprobs_tensors.logprob_token_ids.tolist()
- logprobs_list = logprobs_tensors.logprobs.tolist()
- selected_token_ranks_list = logprobs_tensors.selected_token_ranks.tolist()
-
- if logits_indices_selector is not None:
- log_token_ids_list = [
- log_token_ids_list[i] for i in logits_indices_selector
- ]
- logprobs_list = [logprobs_list[i] for i in logits_indices_selector]
- selected_token_ranks_list = [
- selected_token_ranks_list[i] for i in logits_indices_selector
- ]
-
- return LogprobsLists(
- logprob_token_ids=np.asarray(log_token_ids_list),
- logprobs=np.asarray(logprobs_list),
- sampled_token_ranks=np.asarray(selected_token_ranks_list),
- cu_num_generated_tokens=cu_num_generated_tokens,
- )
-
-
  class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

  def __init__(
  self,
  vllm_config: VllmConfig,
  devices: List[Any],
- rank: int = 0,
- is_first_rank: bool = True,
- is_last_rank: bool = True,
  ):
  self.vllm_config = vllm_config
  self.model_config = vllm_config.model_config
@@ -234,9 +214,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  self.maybe_forbid_compile = runner_utils.ForbidCompile(
  ) if envs.VLLM_XLA_CHECK_RECOMPILATION else nullcontext()
  self.dp_size = self.vllm_config.sharding_config.total_dp_size
- self.rank = rank
- self.is_first_rank = is_first_rank
- self.is_last_rank = is_last_rank

  self._init_random()
  self._init_mesh()
@@ -247,21 +224,31 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

  # Delegate functions to specific manager classes.
  self.compilation_manager = CompilationManager(self)
- if self.is_last_rank:
- self.speculative_decoding_manager = SpeculativeDecodingManager(
- self)
- self.structured_decoding_manager = StructuredDecodingManager(self)
+ self.speculative_decoding_manager = SpeculativeDecodingManager(self)
+ self.structured_decoding_manager = StructuredDecodingManager(self)
  self.kv_cache_manager = KVCacheManager(self)
  self.mm_manager = MultiModalManager(self)
  self.persistent_batch_manager = PersistentBatchManager(
  self.requests, self.input_batch, self.encoder_cache,
- self.uses_mrope, self.model_config, self.is_last_rank)
+ self.uses_mrope, self.model_config)
  self.lora_utils = LoraUtils(self)

- cache_dtype = self.cache_config.cache_dtype
- if cache_dtype == "auto":
- cache_dtype = self.dtype
- self.kv_cache_dtype = to_torch_dtype(cache_dtype)
+ cache_config = self.cache_config
+ if cache_config.cache_dtype == "auto":
+ model_dtype = self.dtype
+ if isinstance(model_dtype, str):
+ self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
+ elif isinstance(getattr(model_dtype, 'dtype', None), jnp.dtype):
+ self.kv_cache_dtype = j2t_dtype(model_dtype.dtype)
+ elif isinstance(model_dtype, torch.dtype):
+ self.kv_cache_dtype = model_dtype
+ else:
+ raise ValueError(
+ "KV cache is unsupported for model_dtype of %s",
+ model_dtype)
+ else:
+ self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[
+ cache_config.cache_dtype]

  self._pre_async_results: AsyncPreResults | None = None
  self._substitute_placeholder_token_fn = _substitute_placeholder_token
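
Note on the hunk above: the new constructor resolves the KV-cache dtype either from an explicit cache_dtype string or from the model dtype, which may arrive as a string, a JAX dtype, or a torch.dtype. The following is a minimal standalone sketch of that resolution logic; the helper name resolve_kv_cache_dtype and its free-function form are illustrative only, not part of the package.

import jax.numpy as jnp
import torch
from torchax.ops.mappings import j2t_dtype  # jax -> torch dtype mapping used in the diff

TPU_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
    "fp8": torch.float8_e4m3fn,
    "fp8_e4m3": torch.float8_e4m3fn,
    "fp8_e5m2": torch.float8_e5m2,
    "int8": torch.int8,
    "uint8": torch.uint8,
}

def resolve_kv_cache_dtype(cache_dtype, model_dtype):
    # Explicit cache dtype strings ("fp8", "bfloat16", ...) map directly.
    if cache_dtype != "auto":
        return TPU_STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
    # "auto": fall back to the model dtype, whichever form it takes.
    if isinstance(model_dtype, str):
        return TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
    if isinstance(getattr(model_dtype, "dtype", None), jnp.dtype):
        return j2t_dtype(model_dtype.dtype)
    if isinstance(model_dtype, torch.dtype):
        return model_dtype
    raise ValueError(f"KV cache is unsupported for model_dtype of {model_dtype}")

For example, resolve_kv_cache_dtype("auto", torch.bfloat16) returns torch.bfloat16, while resolve_kv_cache_dtype("fp8", torch.bfloat16) returns torch.float8_e4m3fn.
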
@@ -275,7 +262,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  self.rng_key = jax.random.key(self.model_config.seed)

  def _init_mesh(self) -> None:
- if envs.NEW_MODEL_DESIGN:
+ if os.getenv("NEW_MODEL_DESIGN", False):
  self.mesh = self._create_new_model_mesh()
  else:
  # NOTE(wenxindongwork): The new MoE kernel expects a 2D mesh, so we need
@@ -286,7 +273,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  logger.info(f"Init mesh | mesh={self.mesh}")

  def _create_new_model_mesh(self) -> jax.sharding.Mesh:
- num_slices = envs.NUM_SLICES
+ num_slices = int(os.environ.get('NUM_SLICES', 1))

  logger.info(f"Creating new model mesh | devices={len(self.devices)}, "
  f"num_slices={num_slices}")
@@ -355,7 +342,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  devices=self.devices)

  def _init_phased_profiling(self) -> None:
- self.phased_profiling_dir = envs.PHASED_PROFILING_DIR
+ self.phased_profiling_dir = os.getenv("PHASED_PROFILING_DIR", "")
  self.phase_based_profiler = None
  if self.phased_profiling_dir:
  self.phase_based_profiler = runner_utils.PhasedBasedProfiler(
@@ -397,7 +384,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  min_token_size=max(16, self.dp_size),
  max_token_size=scheduler_config.max_num_batched_tokens *
  self.dp_size,
- padding_gap=vllm_envs.VLLM_TPU_BUCKET_PADDING_GAP)
+ padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
  self.num_tokens_paddings_per_dp = [
  padding // self.dp_size for padding in self.num_tokens_paddings
  ]
@@ -421,14 +408,8 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

  self.input_ids_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
  self.positions_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
- # Note: self.input_batch and self.block_tables_cpu are both initialized
- # with only 1 block_size. For hybrid kv cache, it will be re-init
- # in kv_cache_manager's maybe_reinitialize_input_batch.
- self.block_tables_cpu = [
- np.zeros((self.max_num_reqs, self.max_num_blocks_per_req),
- dtype=np.int32)
- ]
-
+ self.block_table_cpu = np.zeros(
+ (self.max_num_reqs, self.max_num_blocks_per_req), dtype=np.int32)
  self.query_start_loc_cpu = np.zeros(self.max_num_reqs + self.dp_size,
  dtype=np.int32)
  self.seq_lens_cpu = np.zeros(self.max_num_reqs, dtype=np.int32)
@@ -462,6 +443,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

  # tensors for structured decoding
  self.vocab_size = self.model_config.get_vocab_size()
+ if self.lora_config is not None:
+ # lora_config.lora_extra_vocab_size is the "Maximum size of extra vocabulary that can be present in a LoRA adapter" per https://github.com/vanbasten23/vllm/blob/7f4a8b6705622fde952a2e633e86716f902d6e1b/vllm/config.py#L3040
+ self.vocab_size += self.lora_config.lora_extra_vocab_size
  self.grammar_bitmask_cpu = np.zeros(
  (self.max_num_reqs, cdiv(self.vocab_size, 32)),
  dtype=np.int32,
@@ -506,14 +490,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

  self.rng_params_for_sampling = nnx.Rngs(
  jax.random.key(self.model_config.seed)).params()
- self.is_multimodal_model = (
- self.model_config.is_multimodal_model
- and self.get_multimodal_embeddings_fn is not None and hasattr(
- self.model_config.hf_config, "architectures"
- ) #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
- and len(self.model_config.hf_config.architectures) >= 1
- and self.model_config.hf_config.architectures[0]
- != "Llama4ForConditionalGeneration")
+ self.is_multimodal_model = (self.model_config.is_multimodal_model
+ and self.get_multimodal_embeddings_fn
+ is not None)

  logger.info(f"Init model | "
  f"hbm={common_utils.hbm_usage_gb(self.devices)}GiB")
@@ -526,7 +505,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

  def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
  self.kv_cache_config = kv_cache_config
- self.use_hybrid_kvcache = len(kv_cache_config.kv_cache_groups) > 1
  self.kv_caches = []
  self.kv_cache_manager.initialize_kv_cache(kv_cache_config)
  if has_kv_transfer_group():
@@ -539,12 +517,12 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  def execute_model(
  self,
  scheduler_output: "VllmSchedulerOutput",
- intermediate_tensors: Optional[JaxIntermediateTensors] = None,
- ) -> ModelRunnerOutput | JaxIntermediateTensors | None:
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ ) -> ModelRunnerOutput | None:
  if self.execute_model_state is not None:
  raise RuntimeError("State error: sample_tokens() must be called "
  "after execute_model() returns None.")
- _, output = self._execute_model(scheduler_output, intermediate_tensors)
+ _, output = self._execute_model(scheduler_output)
  return output

  def sample_tokens(
@@ -557,17 +535,16 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

  (scheduler_output, attn_metadata, input_ids, hidden_states, logits,
  aux_hidden_states, spec_decode_metadata, kv_connector_output,
- logits_indices_selector,
- padded_num_reqs) = (self.execute_model_state.scheduler_output,
- self.execute_model_state.attn_metadata,
- self.execute_model_state.input_ids,
- self.execute_model_state.hidden_states,
- self.execute_model_state.logits,
- self.execute_model_state.aux_hidden_states,
- self.execute_model_state.spec_decode_metadata,
- self.execute_model_state.kv_connector_output,
- self.execute_model_state.logits_indices_selector,
- self.execute_model_state.padded_num_reqs)
+ logits_indices_selector) = (
+ self.execute_model_state.scheduler_output,
+ self.execute_model_state.attn_metadata,
+ self.execute_model_state.input_ids,
+ self.execute_model_state.hidden_states,
+ self.execute_model_state.logits,
+ self.execute_model_state.aux_hidden_states,
+ self.execute_model_state.spec_decode_metadata,
+ self.execute_model_state.kv_connector_output,
+ self.execute_model_state.logits_indices_selector)
  self.execute_model_state = None

  if grammar_output is not None:
@@ -581,10 +558,12 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  logits,
  arange,
  )
- return self._sample_from_logits(
- scheduler_output, attn_metadata, input_ids, hidden_states, logits,
- aux_hidden_states, spec_decode_metadata, kv_connector_output,
- logits_indices_selector, padded_num_reqs)
+ return self._sample_from_logits(scheduler_output, attn_metadata,
+ input_ids, hidden_states, logits,
+ aux_hidden_states,
+ spec_decode_metadata,
+ kv_connector_output,
+ logits_indices_selector)

  def _modify_prev_results(self):
  # If copy to host has not been done, we just wait.
@@ -670,9 +649,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  def _execute_model(
  self,
  scheduler_output: "VllmSchedulerOutput",
- intermediate_tensors: Optional[JaxIntermediateTensors] = None,
- ) -> tuple[AttentionMetadata, JaxIntermediateTensors | ModelRunnerOutput
- | None]:
+ ) -> tuple[AttentionMetadata, ModelRunnerOutput | None]:
  self.persistent_batch_manager.update_states(
  scheduler_output, self.get_mrope_input_positions_fn)
  if not scheduler_output.total_num_scheduled_tokens:
@@ -695,23 +672,13 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  # TODO(pooyam): I guess we can remove returning sampling_metadata in `_prepare_inputs` after https://github.com/njhill/vllm/commit/b7433ca1a47732394b1bdea4099d98389515954b
  (
  input_ids,
- input_positions,
  attn_metadata,
  _,
  logits_indices,
  spec_decode_metadata,
  logits_indices_selector,
- padded_num_reqs,
  ) = self._prepare_inputs(scheduler_output)

- is_llama_guard_4 = (
- hasattr(
- self.model_config.hf_config, "architectures"
- ) #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
- and len(self.model_config.hf_config.architectures) >= 1
- and self.model_config.hf_config.architectures[0]
- == "Llama4ForConditionalGeneration")
-
  # multi-modal support
  if self.is_multimodal_model:
  # Run the multimodal encoder if any.
@@ -719,13 +686,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  self.mm_manager.execute_mm_encoder(scheduler_output)
  mm_embeds = self.mm_manager.gather_mm_embeddings(
  scheduler_output, input_ids.shape[0])
- #TODO: Remove the follow elif statement once Llama Guard 4 Vision portion has been implemented
- elif is_llama_guard_4 and any(
- self.mm_manager.runner.requests[req_id].mm_features
- for req_id in self.mm_manager.runner.input_batch.req_ids):
- raise NotImplementedError(
- "Llama Guard 4 (JAX) currently supports only text inputs. "
- "Multimodal processing not yet implemented.")
  else:
  mm_embeds = []

@@ -750,6 +710,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  scheduler_output) as kv_connector_output:
  # NOTE(Wenlong): It takes both `input_ids` and `inputs_embeds`,
  # but one of them would be `None`
+
  (self.kv_caches, hidden_states,
  aux_hidden_states) = self.model_fn(
  self.state,
@@ -757,17 +718,10 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  input_ids,
  attn_metadata,
  inputs_embeds,
- input_positions,
  tuple(self.layer_name_to_kvcache_index.items()),
  lora_metadata,
- intermediate_tensors,
- self.is_first_rank,
- self.is_last_rank,
  )
- if not get_pp_group().is_last_rank:
- assert isinstance(hidden_states, JaxIntermediateTensors)
- hidden_states.kv_connector_output = kv_connector_output
- return attn_metadata, hidden_states
+
  hidden_states = self._select_from_array_fn(hidden_states,
  logits_indices)
  logits = self.compute_logits_fn(
@@ -785,8 +739,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  aux_hidden_states=aux_hidden_states,
  spec_decode_metadata=spec_decode_metadata,
  kv_connector_output=kv_connector_output,
- logits_indices_selector=logits_indices_selector,
- padded_num_reqs=padded_num_reqs)
+ logits_indices_selector=logits_indices_selector)
  return attn_metadata, None

  def _sample_from_logits(
@@ -800,44 +753,23 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  spec_decode_metadata: Optional[SpecDecodeMetadata],
  kv_connector_output: Optional[KVConnectorOutput],
  logits_indices_selector: Optional[List[int]] = None,
- padded_num_reqs: Optional[int] = None,
  ) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput:
- if padded_num_reqs is None:
- padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
- self.input_batch.num_reqs, self.max_num_reqs)
-
- sharding = None
- if self.dp_size > 1:
- sharding = NamedSharding(self.mesh,
- PartitionSpec(ShardingAxisName.ATTN_DATA))
-
+ padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
+ self.input_batch.num_reqs, self.max_num_reqs)
  tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
- self.mesh, self.input_batch, padded_num_reqs, sharding=sharding)
-
- # TODO(pooyam): Should we move this to `_prepare_inputs`?
- if tpu_sampling_metadata.do_sampling:
- self.rng_params_for_sampling, step_rng = jax.random.split(
- self.rng_params_for_sampling)
- else:
- step_rng = self.rng_params_for_sampling
-
+ self.mesh, self.input_batch, padded_num_reqs)
  if spec_decode_metadata is None:
  next_tokens = sample(
- step_rng,
+ self.rng_params_for_sampling,
  self.mesh,
  logits,
  tpu_sampling_metadata,
  )
  else:
- if tpu_sampling_metadata.do_sampling:
- bonus_rng, rejection_rng = jax.random.split(step_rng)
- else:
- bonus_rng = step_rng
- rejection_rng = step_rng
  bonus_logits = self._select_from_array_fn(
  logits, spec_decode_metadata.bonus_logits_indices)
  bonus_token_ids = sample(
- bonus_rng,
+ self.rng_params_for_sampling,
  self.mesh,
  bonus_logits,
  tpu_sampling_metadata,
@@ -851,7 +783,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  target_logits=target_logits,
  bonus_token_ids=bonus_token_ids,
  sampling_metadata=tpu_sampling_metadata,
- key=rejection_rng,
+ key=self.rng_params_for_sampling,
  )

  if tpu_sampling_metadata.logprobs:
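
Note on the sampling hunks above: the previous revision split the persistent JAX PRNG key once per sampling step (and again for the bonus and rejection paths), whereas the new code passes self.rng_params_for_sampling directly. Below is a minimal standalone sketch of the per-step key-splitting pattern that was removed; the names are illustrative and not the runner's own.

import jax
import jax.numpy as jnp

def sample_step(key, logits):
    # Split off a fresh subkey so the persistent key is never consumed directly.
    key, step_key = jax.random.split(key)
    tokens = jax.random.categorical(step_key, logits, axis=-1)
    return key, tokens

key = jax.random.key(0)  # persistent sampling key
key, tokens = sample_step(key, jnp.zeros((2, 8)))
key, more_tokens = sample_step(key, jnp.zeros((2, 8)))  # a different draw each step
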
@@ -908,10 +840,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  logits_indices_selector)

  if logprobs is not None:
- # Map logprobs back to the pre-dp shuffling order
- logprobs_lists = _jax_logprobs_to_lists(
- logprobs, logits_indices_selector)
-
+ logprobs_lists = logprobs.tolists()
  else:
  logprobs_lists = None

@@ -979,9 +908,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  req_state.output_token_ids.extend(sampled_ids)

  if logprobs is not None:
- # Map logprobs back to the pre-dp shuffling order
- logprobs_lists = _jax_logprobs_to_lists(logprobs,
- logits_indices_selector)
+ logprobs_lists = logprobs.tolists()
  else:
  logprobs_lists = None

@@ -1329,6 +1256,16 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  mrope_positions = self.mrope_positions_cpu[:, :
  padded_total_num_scheduled_tokens]

+ block_tables = self.block_table_cpu[:self.max_num_reqs]
+ for dp_rank in range(dp_size):
+ req_offset = dp_rank * max_num_reqs_per_dp_rank
+ _num_reqs = num_req_per_dp_rank[dp_rank]
+
+ block_tables[
+ req_offset:req_offset + _num_reqs, :self.
+ max_num_blocks_per_req] = self.input_batch.block_table[
+ 0].get_cpu_tensor()[req_indices_dp[dp_rank]]
+
  query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs +
  dp_size]
  seq_lens = self.seq_lens_cpu[:self.max_num_reqs]
@@ -1370,59 +1307,20 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  if self.uses_mrope:
  positions = mrope_positions

+ # Convert block_tables to 1D on cpu.
+ block_tables = block_tables.reshape(-1)
+
  query_start_loc_cpu = query_start_loc
  logits_indices_cpu = logits_indices
  seq_lens_cpu = seq_lens

- (input_ids, positions, query_start_loc, seq_lens, logits_indices,
- request_distribution) = device_array(
+ (input_ids, positions, block_tables, query_start_loc, seq_lens,
+ logits_indices, request_distribution, logits_indices) = device_array(
  self.mesh,
- (input_ids, positions, query_start_loc, seq_lens, logits_indices,
- request_distribution),
+ (input_ids, positions, block_tables, query_start_loc, seq_lens,
+ logits_indices, request_distribution, logits_indices),
  sharding=data_parallel_attn_sharding,
  )
-
- attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
- uniform_attention_metadata: AttentionMetadata = None
- for kv_cache_gid, kv_cache_group in enumerate(
- self.kv_cache_config.kv_cache_groups):
- block_tables = self.block_tables_cpu[kv_cache_gid][:self.
- max_num_reqs]
- for dp_rank in range(dp_size):
- req_offset = dp_rank * max_num_reqs_per_dp_rank
- _num_reqs = num_req_per_dp_rank[dp_rank]
-
- block_tables[
- req_offset:req_offset + _num_reqs, :self.
- max_num_blocks_per_req] = self.input_batch.block_table[
- 0].get_cpu_tensor()[req_indices_dp[dp_rank]]
- # Convert block_tables to 1D on cpu.
- block_tables = block_tables.reshape(-1)
- block_tables = device_array(
- self.mesh,
- (block_tables),
- sharding=data_parallel_attn_sharding,
- )
-
- attention_metadata_gid = AttentionMetadata(
- input_positions=positions,
- block_tables=block_tables,
- seq_lens=seq_lens,
- query_start_loc=query_start_loc,
- request_distribution=request_distribution,
- )
-
- # This is for making these cpu buffers hidden during tracing
- attention_metadata_gid.query_start_loc_cpu = query_start_loc_cpu
- attention_metadata_gid.seq_lens_cpu = seq_lens_cpu
-
- if not self.use_hybrid_kvcache:
- uniform_attention_metadata = attention_metadata_gid
- else:
- for layer_name in kv_cache_group.layer_names:
- attention_metadata_per_layer[
- layer_name] = attention_metadata_gid
-
  # Async scheduling: substitute placeholder tokens for DP
  if self.scheduler_config.async_scheduling and self._pre_async_results is not None:
  # Collect all token indices that need substitution across all DP ranks
@@ -1451,19 +1349,25 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  padded_total_num_scheduled_tokens,
  )

- if self.use_hybrid_kvcache:
- attention_metadata = attention_metadata_per_layer
- else:
- attention_metadata = uniform_attention_metadata
+ attention_metadata = AttentionMetadata(
+ input_positions=positions,
+ block_tables=block_tables,
+ seq_lens=seq_lens,
+ query_start_loc=query_start_loc,
+ request_distribution=request_distribution,
+ )
+
+ # This is for making these cpu buffers hidden during tracing
+ attention_metadata.query_start_loc_cpu = query_start_loc_cpu
+ attention_metadata.seq_lens_cpu = seq_lens_cpu
+
  return (
  input_ids,
- positions,
  attention_metadata,
  sampling_metadata,
  logits_indices,
  spec_decode_metadata,
  logits_indices_selector,
- padded_num_reqs,
  )

  def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"):
@@ -1564,6 +1468,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  positions = self.positions_cpu[:padded_total_num_scheduled_tokens]
  mrope_positions = self.mrope_positions_cpu[:, :
  padded_total_num_scheduled_tokens]
+ block_tables = self.block_table_cpu[:self.max_num_reqs]
+ block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
+ self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs])

  # TODO(pooyam): Some paddings are up to `num_reqs_paddings` (spec decoding, select hidden states, etc) and some other are to `max_num_reqs` (block table, seq_lens). We should stick to one of them maybe?
  query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1]
@@ -1592,44 +1499,16 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  self.mesh, self.input_batch, padded_num_reqs)
  if self.uses_mrope:
  positions = mrope_positions
+
+ # Convert block_tables to 1D on cpu.
+ block_tables = block_tables.reshape(-1)
+
  query_start_loc_cpu = query_start_loc
  seq_lens_cpu = seq_lens
-
- (input_ids, positions, query_start_loc, seq_lens,
+ (input_ids, positions, block_tables, query_start_loc, seq_lens,
  logits_indices, request_distribution) = device_array(
- self.mesh, (input_ids, positions, query_start_loc, seq_lens,
- logits_indices, request_distribution))
-
- attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
- uniform_attention_metadata: AttentionMetadata = None
- for kv_cache_gid, kv_cache_group in enumerate(
- self.kv_cache_config.kv_cache_groups):
- block_tables = self.block_tables_cpu[kv_cache_gid][:self.
- max_num_reqs]
- block_tables[:num_reqs] = (
- self.input_batch.block_table[kv_cache_gid].get_cpu_tensor()
- [:num_reqs])
- # Convert block_tables to 1D on cpu.
- block_tables = block_tables.reshape(-1)
- block_tables = device_array(self.mesh, (block_tables))
-
- attention_metadata_gid = AttentionMetadata(
- input_positions=positions,
- block_tables=block_tables,
- seq_lens=seq_lens,
- query_start_loc=query_start_loc,
- request_distribution=request_distribution)
- # This is for making these cpu buffers hidden during tracing
- attention_metadata_gid.query_start_loc_cpu = query_start_loc_cpu
- attention_metadata_gid.seq_lens_cpu = seq_lens_cpu
-
- if not self.use_hybrid_kvcache:
- # all layers share the same attention metadata
- uniform_attention_metadata = attention_metadata_gid
- else:
- for layer_name in kv_cache_group.layer_names:
- attention_metadata_per_layer[
- layer_name] = attention_metadata_gid
+ self.mesh, (input_ids, positions, block_tables, query_start_loc,
+ seq_lens, logits_indices, request_distribution))

  if self.scheduler_config.async_scheduling and len(
  token_in_tpu_cur_input_indices) > 0:
@@ -1642,15 +1521,20 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  self.lora_utils.set_active_loras(
  num_scheduled_tokens_per_req, total_num_scheduled_tokens,
  padded_total_num_scheduled_tokens)
- logits_indices_selector = None

- if self.use_hybrid_kvcache:
- attention_metadata = attention_metadata_per_layer
- else:
- attention_metadata = uniform_attention_metadata
- return (input_ids, positions, attention_metadata, sampling_metadata,
- logits_indices, spec_decode_metadata, logits_indices_selector,
- padded_num_reqs)
+ attention_metadata = AttentionMetadata(
+ input_positions=positions,
+ block_tables=block_tables,
+ seq_lens=seq_lens,
+ query_start_loc=query_start_loc,
+ request_distribution=request_distribution)
+
+ # This is for making these cpu buffers hidden during tracing
+ attention_metadata.query_start_loc_cpu = query_start_loc_cpu
+ attention_metadata.seq_lens_cpu = seq_lens_cpu
+ logits_indices_selector = None
+ return (input_ids, attention_metadata, sampling_metadata,
+ logits_indices, spec_decode_metadata, logits_indices_selector)

  def _get_input_ids_embeds(self, input_ids: jax.Array,
  mm_embeds: list[jax.Array]):
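
Note on the _prepare_inputs hunks above: with the hybrid-KV-cache path removed, a single padded block-table buffer is sliced, filled for the live requests, flattened to 1D on the CPU, and attached to one AttentionMetadata shared by all layers. The following is a self-contained numpy sketch of that buffer handling; SimpleAttentionMetadata is a stand-in for tpu_inference's AttentionMetadata, and the sizes are made up for illustration.

from dataclasses import dataclass
import numpy as np

@dataclass
class SimpleAttentionMetadata:  # stand-in for tpu_inference's AttentionMetadata
    block_tables: np.ndarray
    seq_lens: np.ndarray
    query_start_loc: np.ndarray

max_num_reqs, max_num_blocks_per_req, num_reqs = 4, 8, 2
block_table_cpu = np.zeros((max_num_reqs, max_num_blocks_per_req), dtype=np.int32)

# Copy the live per-request block tables into the padded CPU buffer.
live = np.arange(num_reqs * max_num_blocks_per_req, dtype=np.int32).reshape(num_reqs, -1)
block_tables = block_table_cpu[:max_num_reqs]
block_tables[:num_reqs, :max_num_blocks_per_req] = live

# Flatten to 1D on the CPU before moving it to the device, as the runner does.
block_tables = block_tables.reshape(-1)

metadata = SimpleAttentionMetadata(
    block_tables=block_tables,
    seq_lens=np.zeros(max_num_reqs, dtype=np.int32),
    query_start_loc=np.zeros(max_num_reqs + 1, dtype=np.int32),
)
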
@@ -1710,35 +1594,3 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
  mappings=mappings,
  transpose_keys=transpose_keys,
  shard=shard)
-
- def get_intermediate_tensor_spec(self, num_tokens: int):
- impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
- jax_dtype = t2j_dtype(self.dtype) if impl == "vllm" else self.dtype
- num_padded_tokens = runner_utils.get_padded_token_len(
- self.num_tokens_paddings, num_tokens)
- sharding = NamedSharding(self.mesh, PartitionSpec())
- hidden_size = self.model_config.get_hidden_size()
- spec = jax.ShapeDtypeStruct(shape=(num_padded_tokens, hidden_size),
- dtype=jax_dtype,
- sharding=sharding)
- tensor_spec = {"hidden_states": spec, "residual": spec}
- return tensor_spec
-
- def get_uuid_for_jax_transfer(self,
- scheduler_output: "VllmSchedulerOutput",
- rank: int, step: int) -> int:
- '''
- Get a uuid for jax.transfer, here we use the hash of
- scheduler_output + counter_step + sender's rank
- '''
- scheduler_output_str = ""
- if not scheduler_output.num_scheduled_tokens:
- scheduler_output_str = "empty_batch"
- else:
- scheduler_output_str = str(
- sorted(scheduler_output.num_scheduled_tokens.items()))
- unique_str = f'{scheduler_output_str} {step} {rank}'
- import hashlib
- hasher = hashlib.sha1()
- hasher.update(unique_str.encode('utf-8'))
- return int.from_bytes(hasher.digest()[:8], 'big')