tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.11.1.dev202511220812__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.
Files changed (40)
  1. tests/lora/test_layers.py +0 -6
  2. tests/lora/utils.py +0 -8
  3. tpu_inference/__init__.py +22 -3
  4. tpu_inference/core/disagg_utils.py +6 -8
  5. tpu_inference/distributed/tpu_connector.py +2 -3
  6. tpu_inference/distributed/utils.py +3 -2
  7. tpu_inference/envs.py +1 -1
  8. tpu_inference/executors/ray_distributed_executor.py +4 -1
  9. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
  10. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +77 -54
  11. tpu_inference/layers/vllm/sharding.py +2 -2
  12. tpu_inference/lora/torch_punica_tpu.py +1 -2
  13. tpu_inference/models/common/model_loader.py +9 -9
  14. tpu_inference/models/jax/llama3.py +2 -1
  15. tpu_inference/models/jax/llama_eagle3.py +9 -5
  16. tpu_inference/models/jax/llama_guard_4.py +361 -0
  17. tpu_inference/models/jax/qwen2.py +2 -1
  18. tpu_inference/models/jax/qwen2_5_vl.py +2 -1
  19. tpu_inference/models/jax/qwen3.py +2 -1
  20. tpu_inference/models/jax/utils/weight_utils.py +21 -8
  21. tpu_inference/models/vllm/vllm_model_wrapper.py +4 -4
  22. tpu_inference/platforms/tpu_platform.py +5 -2
  23. tpu_inference/runner/compilation_manager.py +33 -15
  24. tpu_inference/runner/kv_cache_manager.py +8 -2
  25. tpu_inference/runner/tpu_runner.py +187 -99
  26. tpu_inference/spec_decode/jax/eagle3.py +2 -1
  27. tpu_inference/tpu_info.py +4 -3
  28. tpu_inference/utils.py +5 -4
  29. tpu_inference/worker/tpu_worker.py +158 -22
  30. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/METADATA +2 -2
  31. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/RECORD +34 -39
  32. tpu_inference/mock/__init__.py +0 -0
  33. tpu_inference/mock/vllm_config_utils.py +0 -28
  34. tpu_inference/mock/vllm_envs.py +0 -1219
  35. tpu_inference/mock/vllm_logger.py +0 -212
  36. tpu_inference/mock/vllm_logging_utils.py +0 -15
  37. tpu_inference/models/jax/phi3.py +0 -376
  38. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/WHEEL +0 -0
  39. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/licenses/LICENSE +0 -0
  40. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/top_level.txt +0 -0
tpu_inference/runner/kv_cache_manager.py CHANGED
@@ -1,15 +1,16 @@
 import functools
-import math
 from typing import TYPE_CHECKING, Dict, List
 
 import jax
 import jax.numpy as jnp
+import numpy as np
 import vllm.envs as envs
 from jax.sharding import NamedSharding, PartitionSpec
 from torchax.ops.mappings import t2j_dtype
 from vllm.attention import Attention
 from vllm.attention.backends.abstract import AttentionType
 from vllm.config import get_layers_from_vllm_config
+from vllm.utils.math_utils import cdiv
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec, MLAAttentionSpec,
                                         SlidingWindowSpec)
@@ -175,6 +176,11 @@ class KVCacheManager:
         )
         self.runner.input_batch = new_input_batch
         self.runner.persistent_batch_manager.input_batch = new_input_batch
+        self.runner.block_tables_cpu = [
+            np.zeros((self.runner.max_num_reqs,
+                      cdiv(self.runner.max_model_len, block_size)),
+                     dtype=np.int32) for block_size in block_sizes
+        ]
 
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         self.maybe_reinitialize_input_batch(kv_cache_config)
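
For context, the list above allocates one CPU block table per KV-cache group, each wide enough to cover max_model_len at that group's block size. A minimal standalone sketch of the same sizing rule (values are hypothetical, not from the diff):

    import numpy as np

    def cdiv(a: int, b: int) -> int:
        # ceiling division, mirroring vllm.utils.math_utils.cdiv
        return -(a // -b)

    max_num_reqs = 8         # hypothetical runner limits
    max_model_len = 4096
    block_sizes = [16, 128]  # e.g. one size per KV-cache group

    block_tables_cpu = [
        np.zeros((max_num_reqs, cdiv(max_model_len, block_size)), dtype=np.int32)
        for block_size in block_sizes
    ]
    print([t.shape for t in block_tables_cpu])  # [(8, 256), (8, 32)]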
@@ -190,7 +196,7 @@ class KVCacheManager:
         num_blocks = kv_cache_tensor.size // page_size_bytes
         dp_size = self.runner.vllm_config.sharding_config.total_dp_size
         # num_blocks must be a multiple of dp_size
-        num_blocks = math.ceil(num_blocks / dp_size) * dp_size
+        num_blocks = (num_blocks // dp_size) * dp_size
         # NOTE: we'll multiply the num_kv_heads by 2 in the function
         kv_cache = create_kv_caches(
             num_blocks=num_blocks,
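
The rounding change is easiest to see with concrete numbers (hypothetical values): the old expression rounded the block count up to a multiple of dp_size, which can exceed the number of blocks that actually fit in the allocated tensor, while the new expression rounds down.

    import math

    num_blocks, dp_size = 10, 4  # hypothetical values
    old = math.ceil(num_blocks / dp_size) * dp_size  # 12 -> more blocks than fit
    new = (num_blocks // dp_size) * dp_size          # 8  -> stays within the tensor
    print(old, new)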
tpu_inference/runner/tpu_runner.py CHANGED
@@ -153,6 +153,7 @@ class ExecuteModelState:
     spec_decode_metadata: Optional[SpecDecodeMetadata]
     kv_connector_output: Optional[KVConnectorOutput]
     logits_indices_selector: Optional[List[int]] = None
+    padded_num_reqs: Optional[int] = None
 
 
 @functools.partial(jax.jit, donate_argnums=(0, 1, 2))
@@ -190,18 +191,28 @@ def _substitute_placeholder_token(
     return input_ids.at[token_in_tpu_cur_input_indices].set(update_values)
 
 
-def _reorder_logits_indices(logprobs_lists, logits_indices_selector):
+def _jax_logprobs_to_lists(logprobs_tensors,
+                           logits_indices_selector=None,
+                           cu_num_generated_tokens=None):
+    """Convert JAX LogprobsTensors to LogprobsLists by converting JAX arrays to numpy."""
+    log_token_ids_list = logprobs_tensors.logprob_token_ids.tolist()
+    logprobs_list = logprobs_tensors.logprobs.tolist()
+    selected_token_ranks_list = logprobs_tensors.selected_token_ranks.tolist()
+
+    if logits_indices_selector is not None:
+        log_token_ids_list = [
+            log_token_ids_list[i] for i in logits_indices_selector
+        ]
+        logprobs_list = [logprobs_list[i] for i in logits_indices_selector]
+        selected_token_ranks_list = [
+            selected_token_ranks_list[i] for i in logits_indices_selector
+        ]
+
     return LogprobsLists(
-        logprob_token_ids=[
-            logprobs_lists.logprob_token_ids[i]
-            for i in logits_indices_selector
-        ],
-        logprobs=[logprobs_lists.logprobs[i] for i in logits_indices_selector],
-        sampled_token_ranks=[
-            logprobs_lists.sampled_token_ranks[i]
-            for i in logits_indices_selector
-        ],
-        cu_num_generated_tokens=logprobs_lists.cu_num_generated_tokens,
+        logprob_token_ids=np.asarray(log_token_ids_list),
+        logprobs=np.asarray(logprobs_list),
+        sampled_token_ranks=np.asarray(selected_token_ranks_list),
+        cu_num_generated_tokens=cu_num_generated_tokens,
     )
 
 
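A small usage sketch of the new helper's reordering behaviour (stand-in data, not the real LogprobsTensors/LogprobsLists types): rows are converted to host-side Python lists first, reordered by the selector, then packed back into numpy arrays.

    import numpy as np

    # stand-ins for logprob tensors: 2 requests, top-2 entries each
    logprob_token_ids = np.array([[11, 12], [21, 22]])
    selector = [1, 0]  # map rows back to the pre-DP-shuffling request order

    rows = logprob_token_ids.tolist()
    reordered = np.asarray([rows[i] for i in selector])
    print(reordered)  # [[21 22]
                      #  [11 12]]
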
@@ -211,6 +222,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         self,
         vllm_config: VllmConfig,
         devices: List[Any],
+        rank: int = 0,
+        is_first_rank: bool = True,
+        is_last_rank: bool = True,
     ):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
@@ -423,8 +437,14 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
         self.input_ids_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
         self.positions_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
-        self.block_table_cpu = np.zeros(
-            (self.max_num_reqs, self.max_num_blocks_per_req), dtype=np.int32)
+        # Note: self.input_batch and self.block_tables_cpu are both initialized
+        # with only 1 block_size. For hybrid kv cache, it will be re-init
+        # in kv_cache_manager's maybe_reinitialize_input_batch.
+        self.block_tables_cpu = [
+            np.zeros((self.max_num_reqs, self.max_num_blocks_per_req),
+                     dtype=np.int32)
+        ]
+
         self.query_start_loc_cpu = np.zeros(self.max_num_reqs + self.dp_size,
                                             dtype=np.int32)
         self.seq_lens_cpu = np.zeros(self.max_num_reqs, dtype=np.int32)
@@ -458,9 +478,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
         # tensors for structured decoding
         self.vocab_size = self.model_config.get_vocab_size()
-        if self.lora_config is not None:
-            # lora_config.lora_extra_vocab_size is the "Maximum size of extra vocabulary that can be present in a LoRA adapter" per https://github.com/vanbasten23/vllm/blob/7f4a8b6705622fde952a2e633e86716f902d6e1b/vllm/config.py#L3040
-            self.vocab_size += self.lora_config.lora_extra_vocab_size
         self.grammar_bitmask_cpu = np.zeros(
             (self.max_num_reqs, cdiv(self.vocab_size, 32)),
             dtype=np.int32,
@@ -505,9 +522,14 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
         self.rng_params_for_sampling = nnx.Rngs(
             jax.random.key(self.model_config.seed)).params()
-        self.is_multimodal_model = (self.model_config.is_multimodal_model
-                                    and self.get_multimodal_embeddings_fn
-                                    is not None)
+        self.is_multimodal_model = (
+            self.model_config.is_multimodal_model
+            and self.get_multimodal_embeddings_fn is not None and hasattr(
+                self.model_config.hf_config, "architectures"
+            )  #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
+            and len(self.model_config.hf_config.architectures) >= 1
+            and self.model_config.hf_config.architectures[0]
+            != "Llama4ForConditionalGeneration")
 
         logger.info(f"Init model | "
                     f"hbm={common_utils.hbm_usage_gb(self.devices)}GiB")
@@ -520,6 +542,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         self.kv_cache_config = kv_cache_config
+        self.use_hybrid_kvcache = len(kv_cache_config.kv_cache_groups) > 1
         self.kv_caches = []
        self.kv_cache_manager.initialize_kv_cache(kv_cache_config)
         if has_kv_transfer_group():
@@ -550,16 +573,17 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
         (scheduler_output, attn_metadata, input_ids, hidden_states, logits,
          aux_hidden_states, spec_decode_metadata, kv_connector_output,
-         logits_indices_selector) = (
-             self.execute_model_state.scheduler_output,
-             self.execute_model_state.attn_metadata,
-             self.execute_model_state.input_ids,
-             self.execute_model_state.hidden_states,
-             self.execute_model_state.logits,
-             self.execute_model_state.aux_hidden_states,
-             self.execute_model_state.spec_decode_metadata,
-             self.execute_model_state.kv_connector_output,
-             self.execute_model_state.logits_indices_selector)
+         logits_indices_selector,
+         padded_num_reqs) = (self.execute_model_state.scheduler_output,
+                             self.execute_model_state.attn_metadata,
+                             self.execute_model_state.input_ids,
+                             self.execute_model_state.hidden_states,
+                             self.execute_model_state.logits,
+                             self.execute_model_state.aux_hidden_states,
+                             self.execute_model_state.spec_decode_metadata,
+                             self.execute_model_state.kv_connector_output,
+                             self.execute_model_state.logits_indices_selector,
+                             self.execute_model_state.padded_num_reqs)
         self.execute_model_state = None
 
         if grammar_output is not None:
@@ -573,12 +597,10 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 logits,
                 arange,
             )
-        return self._sample_from_logits(scheduler_output, attn_metadata,
-                                        input_ids, hidden_states, logits,
-                                        aux_hidden_states,
-                                        spec_decode_metadata,
-                                        kv_connector_output,
-                                        logits_indices_selector)
+        return self._sample_from_logits(
+            scheduler_output, attn_metadata, input_ids, hidden_states, logits,
+            aux_hidden_states, spec_decode_metadata, kv_connector_output,
+            logits_indices_selector, padded_num_reqs)
 
     def _modify_prev_results(self):
         # If copy to host has not been done, we just wait.
@@ -687,13 +709,23 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         # TODO(pooyam): I guess we can remove returning sampling_metadata in `_prepare_inputs` after https://github.com/njhill/vllm/commit/b7433ca1a47732394b1bdea4099d98389515954b
         (
             input_ids,
+            input_positions,
             attn_metadata,
             _,
             logits_indices,
             spec_decode_metadata,
             logits_indices_selector,
+            padded_num_reqs,
         ) = self._prepare_inputs(scheduler_output)
 
+        is_llama_guard_4 = (
+            hasattr(
+                self.model_config.hf_config, "architectures"
+            )  #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
+            and len(self.model_config.hf_config.architectures) >= 1
+            and self.model_config.hf_config.architectures[0]
+            == "Llama4ForConditionalGeneration")
+
         # multi-modal support
         if self.is_multimodal_model:
             # Run the multimodal encoder if any.
@@ -701,6 +733,13 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 self.mm_manager.execute_mm_encoder(scheduler_output)
                 mm_embeds = self.mm_manager.gather_mm_embeddings(
                     scheduler_output, input_ids.shape[0])
+        #TODO: Remove the follow elif statement once Llama Guard 4 Vision portion has been implemented
+        elif is_llama_guard_4 and any(
+                self.mm_manager.runner.requests[req_id].mm_features
+                for req_id in self.mm_manager.runner.input_batch.req_ids):
+            raise NotImplementedError(
+                "Llama Guard 4 (JAX) currently supports only text inputs. "
+                "Multimodal processing not yet implemented.")
         else:
             mm_embeds = []
 
@@ -733,6 +772,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 input_ids,
                 attn_metadata,
                 inputs_embeds,
+                input_positions,
                 tuple(self.layer_name_to_kvcache_index.items()),
                 lora_metadata,
             )
@@ -754,7 +794,8 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             aux_hidden_states=aux_hidden_states,
             spec_decode_metadata=spec_decode_metadata,
             kv_connector_output=kv_connector_output,
-            logits_indices_selector=logits_indices_selector)
+            logits_indices_selector=logits_indices_selector,
+            padded_num_reqs=padded_num_reqs)
         return attn_metadata, None
 
     def _sample_from_logits(
@@ -768,11 +809,19 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         spec_decode_metadata: Optional[SpecDecodeMetadata],
         kv_connector_output: Optional[KVConnectorOutput],
         logits_indices_selector: Optional[List[int]] = None,
+        padded_num_reqs: Optional[int] = None,
     ) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput:
-        padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
-            self.input_batch.num_reqs, self.max_num_reqs)
+        if padded_num_reqs is None:
+            padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
+                self.input_batch.num_reqs, self.max_num_reqs)
+
+        sharding = None
+        if self.dp_size > 1:
+            sharding = NamedSharding(self.mesh,
+                                     PartitionSpec(ShardingAxisName.ATTN_DATA))
+
         tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
-            self.mesh, self.input_batch, padded_num_reqs)
+            self.mesh, self.input_batch, padded_num_reqs, sharding=sharding)
         if spec_decode_metadata is None:
             next_tokens = sample(
                 self.rng_params_for_sampling,
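
The sharding passed into TPUSupportedSamplingMetadata.from_input_batch follows the standard JAX pattern of pairing a device mesh with a PartitionSpec over a named axis. A minimal standalone sketch (mesh layout and axis name are hypothetical, not taken from this codebase):

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    devices = np.array(jax.devices())              # 1-D mesh over all devices
    mesh = Mesh(devices, axis_names=("attn_dp",))  # hypothetical data-parallel axis

    # shard the leading (per-request) dimension across the data-parallel axis
    sharding = NamedSharding(mesh, PartitionSpec("attn_dp"))
    temperatures = jax.device_put(jnp.ones((8,), jnp.float32), sharding)
    print(temperatures.sharding)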
@@ -856,10 +905,8 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
             if logprobs is not None:
                 # Map logprobs back to the pre-dp shuffling order
-                logprobs_lists = logprobs.tolists()
-                if logits_indices_selector is not None:
-                    logprobs_lists = _reorder_logits_indices(
-                        logprobs_lists, logits_indices_selector)
+                logprobs_lists = _jax_logprobs_to_lists(
+                    logprobs, logits_indices_selector)
 
             else:
                 logprobs_lists = None
@@ -929,10 +976,8 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
         if logprobs is not None:
             # Map logprobs back to the pre-dp shuffling order
-            logprobs_lists = logprobs.tolists()
-            if logits_indices_selector is not None:
-                logprobs_lists = _reorder_logits_indices(
-                    logprobs_lists, logits_indices_selector)
+            logprobs_lists = _jax_logprobs_to_lists(logprobs,
+                                                    logits_indices_selector)
         else:
             logprobs_lists = None
 
@@ -1280,16 +1325,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         mrope_positions = self.mrope_positions_cpu[:, :
                                                    padded_total_num_scheduled_tokens]
 
-        block_tables = self.block_table_cpu[:self.max_num_reqs]
-        for dp_rank in range(dp_size):
-            req_offset = dp_rank * max_num_reqs_per_dp_rank
-            _num_reqs = num_req_per_dp_rank[dp_rank]
-
-            block_tables[
-                req_offset:req_offset + _num_reqs, :self.
-                max_num_blocks_per_req] = self.input_batch.block_table[
-                    0].get_cpu_tensor()[req_indices_dp[dp_rank]]
-
         query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs +
                                                    dp_size]
         seq_lens = self.seq_lens_cpu[:self.max_num_reqs]
@@ -1331,20 +1366,59 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         if self.uses_mrope:
             positions = mrope_positions
 
-        # Convert block_tables to 1D on cpu.
-        block_tables = block_tables.reshape(-1)
-
         query_start_loc_cpu = query_start_loc
         logits_indices_cpu = logits_indices
         seq_lens_cpu = seq_lens
 
-        (input_ids, positions, block_tables, query_start_loc, seq_lens,
-         logits_indices, request_distribution) = device_array(
+        (input_ids, positions, query_start_loc, seq_lens, logits_indices,
+         request_distribution) = device_array(
             self.mesh,
-            (input_ids, positions, block_tables, query_start_loc, seq_lens,
-             logits_indices, request_distribution),
+            (input_ids, positions, query_start_loc, seq_lens, logits_indices,
+             request_distribution),
             sharding=data_parallel_attn_sharding,
         )
+
+        attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+        uniform_attention_metadata: AttentionMetadata = None
+        for kv_cache_gid, kv_cache_group in enumerate(
+                self.kv_cache_config.kv_cache_groups):
+            block_tables = self.block_tables_cpu[kv_cache_gid][:self.
+                                                               max_num_reqs]
+            for dp_rank in range(dp_size):
+                req_offset = dp_rank * max_num_reqs_per_dp_rank
+                _num_reqs = num_req_per_dp_rank[dp_rank]
+
+                block_tables[
+                    req_offset:req_offset + _num_reqs, :self.
+                    max_num_blocks_per_req] = self.input_batch.block_table[
+                        0].get_cpu_tensor()[req_indices_dp[dp_rank]]
+            # Convert block_tables to 1D on cpu.
+            block_tables = block_tables.reshape(-1)
+            block_tables = device_array(
+                self.mesh,
+                (block_tables),
+                sharding=data_parallel_attn_sharding,
+            )
+
+            attention_metadata_gid = AttentionMetadata(
+                input_positions=positions,
+                block_tables=block_tables,
+                seq_lens=seq_lens,
+                query_start_loc=query_start_loc,
+                request_distribution=request_distribution,
+            )
+
+            # This is for making these cpu buffers hidden during tracing
+            attention_metadata_gid.query_start_loc_cpu = query_start_loc_cpu
+            attention_metadata_gid.seq_lens_cpu = seq_lens_cpu
+
+            if not self.use_hybrid_kvcache:
+                uniform_attention_metadata = attention_metadata_gid
+            else:
+                for layer_name in kv_cache_group.layer_names:
+                    attention_metadata_per_layer[
+                        layer_name] = attention_metadata_gid
+
         # Async scheduling: substitute placeholder tokens for DP
         if self.scheduler_config.async_scheduling and self._pre_async_results is not None:
             # Collect all token indices that need substitution across all DP ranks
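
The loop above either reuses one AttentionMetadata for every layer (single KV-cache group) or maps each layer name to its group's metadata (hybrid KV cache). A schematic sketch of that selection logic with made-up group and layer names:

    # hypothetical groups, e.g. full-attention layers vs. sliding-window layers
    kv_cache_groups = [
        {"layer_names": ["layers.0.attn", "layers.2.attn"]},
        {"layer_names": ["layers.1.attn", "layers.3.attn"]},
    ]
    use_hybrid_kvcache = len(kv_cache_groups) > 1

    per_layer, uniform = {}, None
    for gid, group in enumerate(kv_cache_groups):
        metadata = f"metadata_for_group_{gid}"  # stand-in for AttentionMetadata
        if not use_hybrid_kvcache:
            uniform = metadata                  # all layers share one metadata object
        else:
            for name in group["layer_names"]:
                per_layer[name] = metadata      # hybrid: per-layer lookup by name

    attention_metadata = per_layer if use_hybrid_kvcache else uniform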
@@ -1373,25 +1447,19 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             padded_total_num_scheduled_tokens,
         )
 
-        attention_metadata = AttentionMetadata(
-            input_positions=positions,
-            block_tables=block_tables,
-            seq_lens=seq_lens,
-            query_start_loc=query_start_loc,
-            request_distribution=request_distribution,
-        )
-
-        # This is for making these cpu buffers hidden during tracing
-        attention_metadata.query_start_loc_cpu = query_start_loc_cpu
-        attention_metadata.seq_lens_cpu = seq_lens_cpu
-
+        if self.use_hybrid_kvcache:
+            attention_metadata = attention_metadata_per_layer
+        else:
+            attention_metadata = uniform_attention_metadata
         return (
             input_ids,
+            positions,
             attention_metadata,
             sampling_metadata,
             logits_indices,
             spec_decode_metadata,
             logits_indices_selector,
+            padded_num_reqs,
         )
 
     def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"):
@@ -1492,9 +1560,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         positions = self.positions_cpu[:padded_total_num_scheduled_tokens]
         mrope_positions = self.mrope_positions_cpu[:, :
                                                    padded_total_num_scheduled_tokens]
-        block_tables = self.block_table_cpu[:self.max_num_reqs]
-        block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
-            self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs])
 
         # TODO(pooyam): Some paddings are up to `num_reqs_paddings` (spec decoding, select hidden states, etc) and some other are to `max_num_reqs` (block table, seq_lens). We should stick to one of them maybe?
         query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1]
@@ -1523,16 +1588,44 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             self.mesh, self.input_batch, padded_num_reqs)
         if self.uses_mrope:
             positions = mrope_positions
-
-        # Convert block_tables to 1D on cpu.
-        block_tables = block_tables.reshape(-1)
-
         query_start_loc_cpu = query_start_loc
         seq_lens_cpu = seq_lens
-        (input_ids, positions, block_tables, query_start_loc, seq_lens,
+
+        (input_ids, positions, query_start_loc, seq_lens,
         logits_indices, request_distribution) = device_array(
-            self.mesh, (input_ids, positions, block_tables, query_start_loc,
-                        seq_lens, logits_indices, request_distribution))
+            self.mesh, (input_ids, positions, query_start_loc, seq_lens,
+                        logits_indices, request_distribution))
+
+        attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+        uniform_attention_metadata: AttentionMetadata = None
+        for kv_cache_gid, kv_cache_group in enumerate(
+                self.kv_cache_config.kv_cache_groups):
+            block_tables = self.block_tables_cpu[kv_cache_gid][:self.
+                                                               max_num_reqs]
+            block_tables[:num_reqs] = (
+                self.input_batch.block_table[kv_cache_gid].get_cpu_tensor()
+                [:num_reqs])
+            # Convert block_tables to 1D on cpu.
+            block_tables = block_tables.reshape(-1)
+            block_tables = device_array(self.mesh, (block_tables))
+
+            attention_metadata_gid = AttentionMetadata(
+                input_positions=positions,
+                block_tables=block_tables,
+                seq_lens=seq_lens,
+                query_start_loc=query_start_loc,
+                request_distribution=request_distribution)
+            # This is for making these cpu buffers hidden during tracing
+            attention_metadata_gid.query_start_loc_cpu = query_start_loc_cpu
+            attention_metadata_gid.seq_lens_cpu = seq_lens_cpu
+
+            if not self.use_hybrid_kvcache:
+                # all layers share the same attention metadata
+                uniform_attention_metadata = attention_metadata_gid
+            else:
+                for layer_name in kv_cache_group.layer_names:
+                    attention_metadata_per_layer[
+                        layer_name] = attention_metadata_gid
 
         if self.scheduler_config.async_scheduling and len(
                 token_in_tpu_cur_input_indices) > 0:
@@ -1545,20 +1638,15 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             self.lora_utils.set_active_loras(
                 num_scheduled_tokens_per_req, total_num_scheduled_tokens,
                 padded_total_num_scheduled_tokens)
-
-        attention_metadata = AttentionMetadata(
-            input_positions=positions,
-            block_tables=block_tables,
-            seq_lens=seq_lens,
-            query_start_loc=query_start_loc,
-            request_distribution=request_distribution)
-
-        # This is for making these cpu buffers hidden during tracing
-        attention_metadata.query_start_loc_cpu = query_start_loc_cpu
-        attention_metadata.seq_lens_cpu = seq_lens_cpu
         logits_indices_selector = None
-        return (input_ids, attention_metadata, sampling_metadata,
-                logits_indices, spec_decode_metadata, logits_indices_selector)
+
+        if self.use_hybrid_kvcache:
+            attention_metadata = attention_metadata_per_layer
+        else:
+            attention_metadata = uniform_attention_metadata
+        return (input_ids, positions, attention_metadata, sampling_metadata,
+                logits_indices, spec_decode_metadata, logits_indices_selector,
+                padded_num_reqs)
 
     def _get_input_ids_embeds(self, input_ids: jax.Array,
                               mm_embeds: list[jax.Array]):
tpu_inference/spec_decode/jax/eagle3.py CHANGED
@@ -51,7 +51,8 @@ class Eagle3Proposer:
         """Loads the draft model."""
         self.model_fn, self.compute_logits_fn, self.combine_hidden_states_fn, _, self.state, _, _ = get_model(
             self.vllm_config, self.rng_key, self.mesh, is_draft_model=True)
-        del self.state.model['embed_tokens']
+        if 'embed_tokens' in self.state.model:
+            del self.state.model['embed_tokens']
         self.state.model.embed_tokens = target_model.model.embed
 
     @functools.partial(jax.jit, static_argnums=(0, ))
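
The added guard just makes removal of the draft model's own embedding table idempotent before the draft state is pointed at the target model's embeddings. A toy illustration with a plain dict (names are hypothetical):

    draft_state = {"embed_tokens": "draft-owned table", "layers": "..."}
    target_embed = "target model's embedding table"

    # only drop the draft copy if the checkpoint actually provided one
    if "embed_tokens" in draft_state:
        del draft_state["embed_tokens"]
    draft_state["embed_tokens"] = target_embed  # draft now shares the target embeddings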
tpu_inference/tpu_info.py CHANGED
@@ -3,6 +3,7 @@ import os
 
 import requests
 
+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 
 logger = init_logger(__name__)
@@ -32,14 +33,14 @@ def get_tpu_metadata(key: str = "") -> str:
 
 
 def get_tpu_type() -> str:
-    tpu_type = os.getenv("TPU_ACCELERATOR_TYPE", None)
+    tpu_type = envs.TPU_ACCELERATOR_TYPE
     if tpu_type is None:
         tpu_type = get_tpu_metadata(key="accelerator-type")
     return tpu_type
 
 
 def get_node_name() -> str:
-    tpu_name = os.getenv("TPU_NAME", None)
+    tpu_name = envs.TPU_NAME
     if not tpu_name:
         tpu_name = get_tpu_metadata(key="instance-id")
     return tpu_name
@@ -47,7 +48,7 @@ def get_node_name() -> str:
 
 def get_node_worker_id() -> int:
     """For multi-host TPU VM, this returns the worker id for the current node."""
-    worker_id = os.getenv("TPU_WORKER_ID", None)
+    worker_id = envs.TPU_WORKER_ID
     if worker_id is None:
         worker_id = get_tpu_metadata(key="agent-worker-number")
     if worker_id is None:
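
These hunks replace direct os.getenv calls with attributes on tpu_inference.envs, centralizing environment lookups. The envs module itself is not shown in this diff, so the following is only an assumed sketch of the kind of lazily resolved, module-level pattern such an interface implies:

    # hypothetical envs-style module: attribute access backed by os.environ
    import os

    _ENV_VARS = {
        "TPU_ACCELERATOR_TYPE": lambda: os.getenv("TPU_ACCELERATOR_TYPE", None),
        "TPU_NAME": lambda: os.getenv("TPU_NAME", None),
        "TPU_WORKER_ID": lambda: os.getenv("TPU_WORKER_ID", None),
        "TPU_MULTIHOST_BACKEND": lambda: os.getenv("TPU_MULTIHOST_BACKEND", "").lower(),
    }

    def __getattr__(name: str):
        # PEP 562 module-level __getattr__ resolves each attribute lazily per access
        if name in _ENV_VARS:
            return _ENV_VARS[name]()
        raise AttributeError(name)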
tpu_inference/utils.py CHANGED
@@ -1,5 +1,4 @@
 # SPDX-License-Identifier: Apache-2.0
-import os
 import time
 from collections import defaultdict
 from collections.abc import Sequence
@@ -14,8 +13,10 @@ from jax._src import mesh as mesh_lib
 from jax._src import xla_bridge as xb
 from jax._src.lib import xla_client as xc
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
-from vllm import envs, utils
+from vllm import envs as vllm_envs
+from vllm import utils
 
+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 
 GBYTES = 1024 * 1024 * 1024
@@ -57,10 +58,10 @@ def get_num_kv_heads_by_tp(num_kv_heads: int, tp_size: int) -> int:
 
 def hbm_usage_bytes(devices: Any) -> List[Tuple[int, int]]:
     usage = []
-    if envs.VLLM_TPU_USING_PATHWAYS:
+    if vllm_envs.VLLM_TPU_USING_PATHWAYS:
         return pathways_hbm_usage_gb(devices)
 
-    multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+    multihost_backend = envs.TPU_MULTIHOST_BACKEND
     if multihost_backend == "ray":
         # MemoryStats is only supported for addressable PjRt devices.
         # Assume all the devices have similar memory usage for now.