tpu-inference 0.11.1.dev202511130813__py3-none-any.whl → 0.11.1.dev202511220812__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (58)
  1. tests/lora/test_layers.py +0 -6
  2. tests/lora/utils.py +0 -8
  3. tests/test_envs.py +182 -0
  4. tests/test_utils.py +23 -14
  5. tpu_inference/__init__.py +22 -3
  6. tpu_inference/core/core_tpu.py +17 -9
  7. tpu_inference/core/disagg_utils.py +6 -8
  8. tpu_inference/distributed/tpu_connector.py +2 -3
  9. tpu_inference/distributed/utils.py +3 -2
  10. tpu_inference/envs.py +1 -1
  11. tpu_inference/executors/ray_distributed_executor.py +27 -11
  12. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
  13. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +110 -64
  14. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +7 -0
  15. tpu_inference/layers/{jax → common}/attention_interface.py +1 -1
  16. tpu_inference/layers/common/quant_methods.py +8 -0
  17. tpu_inference/layers/jax/attention/attention.py +1 -1
  18. tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
  19. tpu_inference/layers/jax/sample/sampling.py +2 -2
  20. tpu_inference/layers/vllm/attention.py +1 -1
  21. tpu_inference/layers/vllm/quantization/__init__.py +7 -3
  22. tpu_inference/layers/vllm/quantization/awq.py +4 -3
  23. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -2
  24. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  25. tpu_inference/layers/vllm/quantization/unquantized.py +4 -3
  26. tpu_inference/layers/vllm/sharding.py +2 -2
  27. tpu_inference/lora/torch_punica_tpu.py +1 -2
  28. tpu_inference/models/common/model_loader.py +12 -11
  29. tpu_inference/models/jax/llama3.py +4 -3
  30. tpu_inference/models/jax/llama_eagle3.py +9 -5
  31. tpu_inference/models/jax/llama_guard_4.py +361 -0
  32. tpu_inference/models/jax/qwen2.py +3 -2
  33. tpu_inference/models/jax/qwen2_5_vl.py +4 -3
  34. tpu_inference/models/jax/qwen3.py +3 -2
  35. tpu_inference/models/jax/utils/weight_utils.py +21 -8
  36. tpu_inference/models/vllm/vllm_model_wrapper.py +22 -10
  37. tpu_inference/platforms/tpu_platform.py +17 -7
  38. tpu_inference/runner/compilation_manager.py +37 -17
  39. tpu_inference/runner/kv_cache.py +1 -1
  40. tpu_inference/runner/kv_cache_manager.py +8 -2
  41. tpu_inference/runner/tpu_runner.py +199 -87
  42. tpu_inference/spec_decode/jax/eagle3.py +2 -1
  43. tpu_inference/tpu_info.py +4 -3
  44. tpu_inference/utils.py +7 -6
  45. tpu_inference/worker/tpu_worker.py +159 -23
  46. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/METADATA +2 -2
  47. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/RECORD +52 -54
  48. tpu_inference/mock/__init__.py +0 -0
  49. tpu_inference/mock/vllm_config_utils.py +0 -28
  50. tpu_inference/mock/vllm_envs.py +0 -1219
  51. tpu_inference/mock/vllm_logger.py +0 -212
  52. tpu_inference/mock/vllm_logging_utils.py +0 -15
  53. tpu_inference/models/jax/phi3.py +0 -376
  54. /tpu_inference/layers/{jax → common}/binary_search.py +0 -0
  55. /tpu_inference/layers/{jax → common}/sharding.py +0 -0
  56. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/WHEEL +0 -0
  57. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/licenses/LICENSE +0 -0
  58. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/top_level.txt +0 -0

tpu_inference/runner/compilation_manager.py
@@ -1,6 +1,6 @@
 import os
 import time
-from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
 
 import jax
 import jax.numpy as jnp
@@ -10,10 +10,10 @@ from jax.sharding import NamedSharding, PartitionSpec
 
 from tpu_inference.core.disagg_utils import is_disagg_enabled
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.jax.sample.sampling import sample
 from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata
-from tpu_inference.layers.jax.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
 from tpu_inference.utils import device_array
 
@@ -135,12 +135,6 @@ class CompilationManager:
             ShardingAxisName.ATTN_DATA, )) if dp_size > 1 else None
 
         # Keep existing pattern for complex array operations
-        block_tables = self.runner.block_table_cpu[:self.runner.max_num_reqs]
-        block_tables = block_tables.reshape(-1)
-        block_tables = device_array(self.runner.mesh,
-                                    block_tables,
-                                    sharding=dp_sharding)
-
         seq_lens = self._create_dummy_tensor((self.runner.max_num_reqs, ),
                                              jnp.int32, dp_sharding)
         query_start_loc = self._create_dummy_tensor(
@@ -152,26 +146,45 @@
             request_distribution,
             sharding=dp_sharding)
 
-        attention_metadata = AttentionMetadata(
-            input_positions=positions,
-            block_tables=block_tables,
-            seq_lens=seq_lens,
-            query_start_loc=query_start_loc,
-            request_distribution=request_distribution,
-        )
+        attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+        uniform_attention_metadata: AttentionMetadata = None
+        for kv_cache_gid, kv_cache_group in enumerate(
+                self.runner.kv_cache_config.kv_cache_groups):
+            block_tables = self.runner.block_tables_cpu[
+                kv_cache_gid][:self.runner.max_num_reqs]
+            block_tables = block_tables.reshape(-1)
+            block_tables = device_array(self.runner.mesh,
+                                        block_tables,
+                                        sharding=dp_sharding)
+
+            attention_metadata_gid = AttentionMetadata(
+                input_positions=positions,
+                block_tables=block_tables,
+                seq_lens=seq_lens,
+                query_start_loc=query_start_loc,
+                request_distribution=request_distribution,
+            )
+            if not self.runner.use_hybrid_kvcache:
+                # all layers share the same attention metadata
+                uniform_attention_metadata = attention_metadata_gid
+            else:
+                for layer_name in kv_cache_group.layer_names:
+                    attention_metadata_per_layer[
+                        layer_name] = attention_metadata_gid
 
         def model_fn_wrapper(
             state,
             kv_caches,
             input_ids,
             attention_metadata,
+            positions,
             inputs_embeds,
             layer_name_to_kvcache_index,
             lora_metadata,
         ):
             kv_caches, hidden_states, _ = self.runner.model_fn(
                 state, kv_caches, input_ids, attention_metadata, inputs_embeds,
-                layer_name_to_kvcache_index, lora_metadata)
+                positions, layer_name_to_kvcache_index, lora_metadata)
             self.runner.kv_caches = kv_caches
             return hidden_states
 
@@ -179,6 +192,10 @@ class CompilationManager:
                 self.runner.lora_config, np.array([num_tokens],
                                                   dtype=np.int32)):
             lora_metadata = self.runner.lora_utils.extract_lora_metadata()
+            if self.runner.use_hybrid_kvcache:
+                attention_metadata = attention_metadata_per_layer
+            else:
+                attention_metadata = uniform_attention_metadata
             self._run_compilation(
                 name,
                 model_fn_wrapper,
@@ -186,6 +203,7 @@
                 self.runner.kv_caches,
                 input_ids,
                 attention_metadata,
+                positions,
                 inputs_embeds,
                 tuple(self.runner.layer_name_to_kvcache_index.items()),
                 lora_metadata,
@@ -332,13 +350,15 @@ class CompilationManager:
         index_paddings = self.runner.num_reqs_paddings
         dp_sharding = NamedSharding(self.runner.mesh,
                                     PartitionSpec(ShardingAxisName.ATTN_DATA))
+        hidden_states_sharding = NamedSharding(
+            self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None))
        dp_size = self.runner.vllm_config.sharding_config.total_dp_size
        self._precompile_select_from_array_helper(
            name="select all logits",
            source_paddings=self.runner.num_tokens_paddings,
            indices_paddings=index_paddings,
            hidden_dim=hsize,
-            input_sharding=dp_sharding,
+            input_sharding=hidden_states_sharding,
            indices_sharding=dp_sharding if dp_size > 1 else None,
        )
 

tpu_inference/layers/vllm/attention.py
@@ -9,7 +9,7 @@ from torchax.ops.mappings import t2j_dtype
 
 import tpu_inference.kernels.ragged_paged_attention.v3.kernel as rpa
 import tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 as rpa_hd64
-from tpu_inference.layers.jax.sharding import ShardingAxisName
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
 
 logger = init_logger(__name__)

tpu_inference/runner/kv_cache_manager.py
@@ -1,15 +1,16 @@
 import functools
-import math
 from typing import TYPE_CHECKING, Dict, List
 
 import jax
 import jax.numpy as jnp
+import numpy as np
 import vllm.envs as envs
 from jax.sharding import NamedSharding, PartitionSpec
 from torchax.ops.mappings import t2j_dtype
 from vllm.attention import Attention
 from vllm.attention.backends.abstract import AttentionType
 from vllm.config import get_layers_from_vllm_config
+from vllm.utils.math_utils import cdiv
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec, MLAAttentionSpec,
                                         SlidingWindowSpec)
@@ -175,6 +176,11 @@ class KVCacheManager:
         )
         self.runner.input_batch = new_input_batch
         self.runner.persistent_batch_manager.input_batch = new_input_batch
+        self.runner.block_tables_cpu = [
+            np.zeros((self.runner.max_num_reqs,
+                      cdiv(self.runner.max_model_len, block_size)),
+                     dtype=np.int32) for block_size in block_sizes
+        ]
 
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         self.maybe_reinitialize_input_batch(kv_cache_config)
@@ -190,7 +196,7 @@
             num_blocks = kv_cache_tensor.size // page_size_bytes
             dp_size = self.runner.vllm_config.sharding_config.total_dp_size
             # num_blocks must be a multiple of dp_size
-            num_blocks = math.ceil(num_blocks / dp_size) * dp_size
+            num_blocks = (num_blocks // dp_size) * dp_size
             # NOTE: we'll multiply the num_kv_heads by 2 in the function
             kv_cache = create_kv_caches(
                 num_blocks=num_blocks,
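
One effect of the kv_cache_manager change above: num_blocks is now rounded down to a multiple of dp_size instead of up, so the block count can no longer exceed what the allocated KV-cache tensor holds. A toy illustration with made-up numbers (not taken from the package):

import math

dp_size = 4
num_blocks = 1022  # e.g. derived from kv_cache_tensor.size // page_size_bytes

old_value = math.ceil(num_blocks / dp_size) * dp_size  # rounds up   -> 1024
new_value = (num_blocks // dp_size) * dp_size          # rounds down -> 1020

# Rounding up would claim more blocks than the tensor was sized for;
# rounding down keeps the dp_size-multiple invariant within the allocation.
print(old_value, new_value)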

tpu_inference/runner/tpu_runner.py
@@ -27,7 +27,7 @@ from vllm.v1.core.sched.output import GrammarOutput
 from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
-                             DraftTokenIds, KVConnectorOutput,
+                             DraftTokenIds, KVConnectorOutput, LogprobsLists,
                              ModelRunnerOutput)
 from vllm.v1.request import Request
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
@@ -37,15 +37,15 @@ from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 
 from tpu_inference import utils as common_utils
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.sharding import (MESH_AXIS_NAMES,
+                                                  MESH_AXIS_NAMES_2D,
+                                                  ShardingAxisName,
+                                                  ShardingConfigManager)
 from tpu_inference.layers.jax.sample.rejection_sampler import RejectionSampler
 from tpu_inference.layers.jax.sample.sampling import (compute_logprobs,
                                                       gather_logprobs, sample)
 from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata
-from tpu_inference.layers.jax.sharding import (MESH_AXIS_NAMES,
-                                               MESH_AXIS_NAMES_2D,
-                                               ShardingAxisName,
-                                               ShardingConfigManager)
 from tpu_inference.logger import init_logger
 from tpu_inference.models.common.model_loader import get_model
 from tpu_inference.models.jax.utils.weight_utils import (
@@ -153,6 +153,7 @@ class ExecuteModelState:
     spec_decode_metadata: Optional[SpecDecodeMetadata]
     kv_connector_output: Optional[KVConnectorOutput]
     logits_indices_selector: Optional[List[int]] = None
+    padded_num_reqs: Optional[int] = None
 
 
 @functools.partial(jax.jit, donate_argnums=(0, 1, 2))
@@ -190,12 +191,40 @@ def _substitute_placeholder_token(
     return input_ids.at[token_in_tpu_cur_input_indices].set(update_values)
 
 
+def _jax_logprobs_to_lists(logprobs_tensors,
+                           logits_indices_selector=None,
+                           cu_num_generated_tokens=None):
+    """Convert JAX LogprobsTensors to LogprobsLists by converting JAX arrays to numpy."""
+    log_token_ids_list = logprobs_tensors.logprob_token_ids.tolist()
+    logprobs_list = logprobs_tensors.logprobs.tolist()
+    selected_token_ranks_list = logprobs_tensors.selected_token_ranks.tolist()
+
+    if logits_indices_selector is not None:
+        log_token_ids_list = [
+            log_token_ids_list[i] for i in logits_indices_selector
+        ]
+        logprobs_list = [logprobs_list[i] for i in logits_indices_selector]
+        selected_token_ranks_list = [
+            selected_token_ranks_list[i] for i in logits_indices_selector
+        ]
+
+    return LogprobsLists(
+        logprob_token_ids=np.asarray(log_token_ids_list),
+        logprobs=np.asarray(logprobs_list),
+        sampled_token_ranks=np.asarray(selected_token_ranks_list),
+        cu_num_generated_tokens=cu_num_generated_tokens,
+    )
+
+
 class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
     def __init__(
         self,
         vllm_config: VllmConfig,
         devices: List[Any],
+        rank: int = 0,
+        is_first_rank: bool = True,
+        is_last_rank: bool = True,
     ):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
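
A rough sketch of what the new _jax_logprobs_to_lists helper does with logits_indices_selector (only the reordering step, before the numpy conversion), using plain Python lists as toy stand-ins for the JAX arrays:

# Shuffled per-request rows as they come back from the DP-sharded model.
logprob_token_ids = [[11, 12], [21, 22], [31, 32]]

# logits_indices_selector[j] names the shuffled row that belongs at output
# position j, so gathering by it restores the pre-shuffle request order.
logits_indices_selector = [2, 0, 1]

restored = [logprob_token_ids[i] for i in logits_indices_selector]
print(restored)  # [[31, 32], [11, 12], [21, 22]]
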
@@ -408,8 +437,14 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
         self.input_ids_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
         self.positions_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
-        self.block_table_cpu = np.zeros(
-            (self.max_num_reqs, self.max_num_blocks_per_req), dtype=np.int32)
+        # Note: self.input_batch and self.block_tables_cpu are both initialized
+        # with only 1 block_size. For hybrid kv cache, it will be re-init
+        # in kv_cache_manager's maybe_reinitialize_input_batch.
+        self.block_tables_cpu = [
+            np.zeros((self.max_num_reqs, self.max_num_blocks_per_req),
+                     dtype=np.int32)
+        ]
+
         self.query_start_loc_cpu = np.zeros(self.max_num_reqs + self.dp_size,
                                             dtype=np.int32)
         self.seq_lens_cpu = np.zeros(self.max_num_reqs, dtype=np.int32)
@@ -443,9 +478,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
         # tensors for structured decoding
         self.vocab_size = self.model_config.get_vocab_size()
-        if self.lora_config is not None:
-            # lora_config.lora_extra_vocab_size is the "Maximum size of extra vocabulary that can be present in a LoRA adapter" per https://github.com/vanbasten23/vllm/blob/7f4a8b6705622fde952a2e633e86716f902d6e1b/vllm/config.py#L3040
-            self.vocab_size += self.lora_config.lora_extra_vocab_size
         self.grammar_bitmask_cpu = np.zeros(
             (self.max_num_reqs, cdiv(self.vocab_size, 32)),
             dtype=np.int32,
@@ -490,9 +522,14 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
         self.rng_params_for_sampling = nnx.Rngs(
             jax.random.key(self.model_config.seed)).params()
-        self.is_multimodal_model = (self.model_config.is_multimodal_model
-                                    and self.get_multimodal_embeddings_fn
-                                    is not None)
+        self.is_multimodal_model = (
+            self.model_config.is_multimodal_model
+            and self.get_multimodal_embeddings_fn is not None and hasattr(
+                self.model_config.hf_config, "architectures"
+            ) #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
+            and len(self.model_config.hf_config.architectures) >= 1
+            and self.model_config.hf_config.architectures[0]
+            != "Llama4ForConditionalGeneration")
 
         logger.info(f"Init model | "
                     f"hbm={common_utils.hbm_usage_gb(self.devices)}GiB")
@@ -505,6 +542,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         self.kv_cache_config = kv_cache_config
+        self.use_hybrid_kvcache = len(kv_cache_config.kv_cache_groups) > 1
         self.kv_caches = []
         self.kv_cache_manager.initialize_kv_cache(kv_cache_config)
         if has_kv_transfer_group():
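
The use_hybrid_kvcache flag set above drives the same dispatch that compilation_manager.py and the _prepare_inputs paths below follow: each KV-cache group builds its own AttentionMetadata, and only when more than one group exists are layers keyed to their group's copy. A minimal standalone sketch of that selection logic, with simplified stand-in types rather than the package's real classes:

from dataclasses import dataclass
from typing import Dict, List, Optional, Union

@dataclass
class AttentionMetadata:          # stand-in for the real metadata object
    block_tables: list

@dataclass
class KVCacheGroup:               # stand-in for a vLLM KV-cache group spec
    layer_names: List[str]

def build_attention_metadata(
    groups: List[KVCacheGroup],
    block_tables_cpu: List[list],
) -> Union[AttentionMetadata, Dict[str, AttentionMetadata]]:
    per_layer: Dict[str, AttentionMetadata] = {}
    uniform: Optional[AttentionMetadata] = None
    for gid, group in enumerate(groups):
        md = AttentionMetadata(block_tables=block_tables_cpu[gid])
        if len(groups) == 1:
            uniform = md              # non-hybrid: one shared metadata object
        else:
            for name in group.layer_names:
                per_layer[name] = md  # hybrid: every layer maps to its group
    return per_layer if len(groups) > 1 else uniform
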
@@ -535,16 +573,17 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
         (scheduler_output, attn_metadata, input_ids, hidden_states, logits,
          aux_hidden_states, spec_decode_metadata, kv_connector_output,
-         logits_indices_selector) = (
-             self.execute_model_state.scheduler_output,
-             self.execute_model_state.attn_metadata,
-             self.execute_model_state.input_ids,
-             self.execute_model_state.hidden_states,
-             self.execute_model_state.logits,
-             self.execute_model_state.aux_hidden_states,
-             self.execute_model_state.spec_decode_metadata,
-             self.execute_model_state.kv_connector_output,
-             self.execute_model_state.logits_indices_selector)
+         logits_indices_selector,
+         padded_num_reqs) = (self.execute_model_state.scheduler_output,
+                             self.execute_model_state.attn_metadata,
+                             self.execute_model_state.input_ids,
+                             self.execute_model_state.hidden_states,
+                             self.execute_model_state.logits,
+                             self.execute_model_state.aux_hidden_states,
+                             self.execute_model_state.spec_decode_metadata,
+                             self.execute_model_state.kv_connector_output,
+                             self.execute_model_state.logits_indices_selector,
+                             self.execute_model_state.padded_num_reqs)
         self.execute_model_state = None
 
         if grammar_output is not None:
@@ -558,12 +597,10 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 logits,
                 arange,
             )
-        return self._sample_from_logits(scheduler_output, attn_metadata,
-                                        input_ids, hidden_states, logits,
-                                        aux_hidden_states,
-                                        spec_decode_metadata,
-                                        kv_connector_output,
-                                        logits_indices_selector)
+        return self._sample_from_logits(
+            scheduler_output, attn_metadata, input_ids, hidden_states, logits,
+            aux_hidden_states, spec_decode_metadata, kv_connector_output,
+            logits_indices_selector, padded_num_reqs)
 
     def _modify_prev_results(self):
         # If copy to host has not been done, we just wait.
@@ -672,13 +709,23 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         # TODO(pooyam): I guess we can remove returning sampling_metadata in `_prepare_inputs` after https://github.com/njhill/vllm/commit/b7433ca1a47732394b1bdea4099d98389515954b
         (
             input_ids,
+            input_positions,
             attn_metadata,
             _,
             logits_indices,
             spec_decode_metadata,
             logits_indices_selector,
+            padded_num_reqs,
         ) = self._prepare_inputs(scheduler_output)
 
+        is_llama_guard_4 = (
+            hasattr(
+                self.model_config.hf_config, "architectures"
+            ) #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
+            and len(self.model_config.hf_config.architectures) >= 1
+            and self.model_config.hf_config.architectures[0]
+            == "Llama4ForConditionalGeneration")
+
         # multi-modal support
         if self.is_multimodal_model:
             # Run the multimodal encoder if any.
@@ -686,6 +733,13 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             self.mm_manager.execute_mm_encoder(scheduler_output)
             mm_embeds = self.mm_manager.gather_mm_embeddings(
                 scheduler_output, input_ids.shape[0])
+        #TODO: Remove the follow elif statement once Llama Guard 4 Vision portion has been implemented
+        elif is_llama_guard_4 and any(
+                self.mm_manager.runner.requests[req_id].mm_features
+                for req_id in self.mm_manager.runner.input_batch.req_ids):
+            raise NotImplementedError(
+                "Llama Guard 4 (JAX) currently supports only text inputs. "
+                "Multimodal processing not yet implemented.")
         else:
             mm_embeds = []
 
@@ -718,6 +772,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             input_ids,
             attn_metadata,
             inputs_embeds,
+            input_positions,
             tuple(self.layer_name_to_kvcache_index.items()),
             lora_metadata,
         )
@@ -739,7 +794,8 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             aux_hidden_states=aux_hidden_states,
             spec_decode_metadata=spec_decode_metadata,
             kv_connector_output=kv_connector_output,
-            logits_indices_selector=logits_indices_selector)
+            logits_indices_selector=logits_indices_selector,
+            padded_num_reqs=padded_num_reqs)
         return attn_metadata, None
 
     def _sample_from_logits(
@@ -753,11 +809,19 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         spec_decode_metadata: Optional[SpecDecodeMetadata],
         kv_connector_output: Optional[KVConnectorOutput],
         logits_indices_selector: Optional[List[int]] = None,
+        padded_num_reqs: Optional[int] = None,
     ) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput:
-        padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
-            self.input_batch.num_reqs, self.max_num_reqs)
+        if padded_num_reqs is None:
+            padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
+                self.input_batch.num_reqs, self.max_num_reqs)
+
+        sharding = None
+        if self.dp_size > 1:
+            sharding = NamedSharding(self.mesh,
+                                     PartitionSpec(ShardingAxisName.ATTN_DATA))
+
         tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
-            self.mesh, self.input_batch, padded_num_reqs)
+            self.mesh, self.input_batch, padded_num_reqs, sharding=sharding)
         if spec_decode_metadata is None:
             next_tokens = sample(
                 self.rng_params_for_sampling,
@@ -840,7 +904,10 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                                          logits_indices_selector)
 
             if logprobs is not None:
-                logprobs_lists = logprobs.tolists()
+                # Map logprobs back to the pre-dp shuffling order
+                logprobs_lists = _jax_logprobs_to_lists(
+                    logprobs, logits_indices_selector)
+
             else:
                 logprobs_lists = None
 
@@ -908,7 +975,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 req_state.output_token_ids.extend(sampled_ids)
 
             if logprobs is not None:
-                logprobs_lists = logprobs.tolists()
+                # Map logprobs back to the pre-dp shuffling order
+                logprobs_lists = _jax_logprobs_to_lists(logprobs,
+                                                        logits_indices_selector)
             else:
                 logprobs_lists = None
 
@@ -1256,16 +1325,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         mrope_positions = self.mrope_positions_cpu[:, :
                                                    padded_total_num_scheduled_tokens]
 
-        block_tables = self.block_table_cpu[:self.max_num_reqs]
-        for dp_rank in range(dp_size):
-            req_offset = dp_rank * max_num_reqs_per_dp_rank
-            _num_reqs = num_req_per_dp_rank[dp_rank]
-
-            block_tables[
-                req_offset:req_offset + _num_reqs, :self.
-                max_num_blocks_per_req] = self.input_batch.block_table[
-                    0].get_cpu_tensor()[req_indices_dp[dp_rank]]
-
         query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs +
                                                    dp_size]
         seq_lens = self.seq_lens_cpu[:self.max_num_reqs]
@@ -1307,20 +1366,59 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         if self.uses_mrope:
             positions = mrope_positions
 
-        # Convert block_tables to 1D on cpu.
-        block_tables = block_tables.reshape(-1)
-
         query_start_loc_cpu = query_start_loc
         logits_indices_cpu = logits_indices
         seq_lens_cpu = seq_lens
 
-        (input_ids, positions, block_tables, query_start_loc, seq_lens,
-         logits_indices, request_distribution, logits_indices) = device_array(
+        (input_ids, positions, query_start_loc, seq_lens, logits_indices,
+         request_distribution) = device_array(
             self.mesh,
-            (input_ids, positions, block_tables, query_start_loc, seq_lens,
-             logits_indices, request_distribution, logits_indices),
+            (input_ids, positions, query_start_loc, seq_lens, logits_indices,
+             request_distribution),
             sharding=data_parallel_attn_sharding,
         )
+
+        attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+        uniform_attention_metadata: AttentionMetadata = None
+        for kv_cache_gid, kv_cache_group in enumerate(
+                self.kv_cache_config.kv_cache_groups):
+            block_tables = self.block_tables_cpu[kv_cache_gid][:self.
+                                                               max_num_reqs]
+            for dp_rank in range(dp_size):
+                req_offset = dp_rank * max_num_reqs_per_dp_rank
+                _num_reqs = num_req_per_dp_rank[dp_rank]
+
+                block_tables[
+                    req_offset:req_offset + _num_reqs, :self.
+                    max_num_blocks_per_req] = self.input_batch.block_table[
+                        0].get_cpu_tensor()[req_indices_dp[dp_rank]]
+            # Convert block_tables to 1D on cpu.
+            block_tables = block_tables.reshape(-1)
+            block_tables = device_array(
+                self.mesh,
+                (block_tables),
+                sharding=data_parallel_attn_sharding,
+            )
+
+            attention_metadata_gid = AttentionMetadata(
+                input_positions=positions,
+                block_tables=block_tables,
+                seq_lens=seq_lens,
+                query_start_loc=query_start_loc,
+                request_distribution=request_distribution,
+            )
+
+            # This is for making these cpu buffers hidden during tracing
+            attention_metadata_gid.query_start_loc_cpu = query_start_loc_cpu
+            attention_metadata_gid.seq_lens_cpu = seq_lens_cpu
+
+            if not self.use_hybrid_kvcache:
+                uniform_attention_metadata = attention_metadata_gid
+            else:
+                for layer_name in kv_cache_group.layer_names:
+                    attention_metadata_per_layer[
+                        layer_name] = attention_metadata_gid
+
         # Async scheduling: substitute placeholder tokens for DP
         if self.scheduler_config.async_scheduling and self._pre_async_results is not None:
             # Collect all token indices that need substitution across all DP ranks
@@ -1349,25 +1447,19 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             padded_total_num_scheduled_tokens,
         )
 
-        attention_metadata = AttentionMetadata(
-            input_positions=positions,
-            block_tables=block_tables,
-            seq_lens=seq_lens,
-            query_start_loc=query_start_loc,
-            request_distribution=request_distribution,
-        )
-
-        # This is for making these cpu buffers hidden during tracing
-        attention_metadata.query_start_loc_cpu = query_start_loc_cpu
-        attention_metadata.seq_lens_cpu = seq_lens_cpu
-
+        if self.use_hybrid_kvcache:
+            attention_metadata = attention_metadata_per_layer
+        else:
+            attention_metadata = uniform_attention_metadata
         return (
             input_ids,
+            positions,
             attention_metadata,
             sampling_metadata,
             logits_indices,
             spec_decode_metadata,
             logits_indices_selector,
+            padded_num_reqs,
         )
 
     def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"):
@@ -1468,9 +1560,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         positions = self.positions_cpu[:padded_total_num_scheduled_tokens]
         mrope_positions = self.mrope_positions_cpu[:, :
                                                    padded_total_num_scheduled_tokens]
-        block_tables = self.block_table_cpu[:self.max_num_reqs]
-        block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
-            self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs])
 
         # TODO(pooyam): Some paddings are up to `num_reqs_paddings` (spec decoding, select hidden states, etc) and some other are to `max_num_reqs` (block table, seq_lens). We should stick to one of them maybe?
         query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1]
@@ -1499,16 +1588,44 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             self.mesh, self.input_batch, padded_num_reqs)
         if self.uses_mrope:
             positions = mrope_positions
-
-        # Convert block_tables to 1D on cpu.
-        block_tables = block_tables.reshape(-1)
-
         query_start_loc_cpu = query_start_loc
         seq_lens_cpu = seq_lens
-        (input_ids, positions, block_tables, query_start_loc, seq_lens,
+
+        (input_ids, positions, query_start_loc, seq_lens,
         logits_indices, request_distribution) = device_array(
-            self.mesh, (input_ids, positions, block_tables, query_start_loc,
-                        seq_lens, logits_indices, request_distribution))
+            self.mesh, (input_ids, positions, query_start_loc, seq_lens,
+                        logits_indices, request_distribution))
+
+        attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+        uniform_attention_metadata: AttentionMetadata = None
+        for kv_cache_gid, kv_cache_group in enumerate(
+                self.kv_cache_config.kv_cache_groups):
+            block_tables = self.block_tables_cpu[kv_cache_gid][:self.
+                                                               max_num_reqs]
+            block_tables[:num_reqs] = (
+                self.input_batch.block_table[kv_cache_gid].get_cpu_tensor()
+                [:num_reqs])
+            # Convert block_tables to 1D on cpu.
+            block_tables = block_tables.reshape(-1)
+            block_tables = device_array(self.mesh, (block_tables))
+
+            attention_metadata_gid = AttentionMetadata(
+                input_positions=positions,
+                block_tables=block_tables,
+                seq_lens=seq_lens,
+                query_start_loc=query_start_loc,
+                request_distribution=request_distribution)
+            # This is for making these cpu buffers hidden during tracing
+            attention_metadata_gid.query_start_loc_cpu = query_start_loc_cpu
+            attention_metadata_gid.seq_lens_cpu = seq_lens_cpu
+
+            if not self.use_hybrid_kvcache:
+                # all layers share the same attention metadata
+                uniform_attention_metadata = attention_metadata_gid
+            else:
+                for layer_name in kv_cache_group.layer_names:
+                    attention_metadata_per_layer[
+                        layer_name] = attention_metadata_gid
 
         if self.scheduler_config.async_scheduling and len(
                 token_in_tpu_cur_input_indices) > 0:
@@ -1521,20 +1638,15 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             self.lora_utils.set_active_loras(
                 num_scheduled_tokens_per_req, total_num_scheduled_tokens,
                 padded_total_num_scheduled_tokens)
-
-        attention_metadata = AttentionMetadata(
-            input_positions=positions,
-            block_tables=block_tables,
-            seq_lens=seq_lens,
-            query_start_loc=query_start_loc,
-            request_distribution=request_distribution)
-
-        # This is for making these cpu buffers hidden during tracing
-        attention_metadata.query_start_loc_cpu = query_start_loc_cpu
-        attention_metadata.seq_lens_cpu = seq_lens_cpu
         logits_indices_selector = None
-        return (input_ids, attention_metadata, sampling_metadata,
-                logits_indices, spec_decode_metadata, logits_indices_selector)
+
+        if self.use_hybrid_kvcache:
+            attention_metadata = attention_metadata_per_layer
+        else:
+            attention_metadata = uniform_attention_metadata
+        return (input_ids, positions, attention_metadata, sampling_metadata,
+                logits_indices, spec_decode_metadata, logits_indices_selector,
+                padded_num_reqs)
 
     def _get_input_ids_embeds(self, input_ids: jax.Array,
                               mm_embeds: list[jax.Array]):

tpu_inference/spec_decode/jax/eagle3.py
@@ -51,7 +51,8 @@ class Eagle3Proposer:
         """Loads the draft model."""
         self.model_fn, self.compute_logits_fn, self.combine_hidden_states_fn, _, self.state, _, _ = get_model(
             self.vllm_config, self.rng_key, self.mesh, is_draft_model=True)
-        del self.state.model['embed_tokens']
+        if 'embed_tokens' in self.state.model:
+            del self.state.model['embed_tokens']
         self.state.model.embed_tokens = target_model.model.embed
 
     @functools.partial(jax.jit, static_argnums=(0, ))