tpu-inference 0.11.1.dev202511150811__py3-none-any.whl → 0.11.1.dev202511270815__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (49)
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/lora/test_layers.py +0 -6
  3. tests/lora/utils.py +0 -8
  4. tpu_inference/__init__.py +22 -3
  5. tpu_inference/core/disagg_utils.py +6 -8
  6. tpu_inference/distributed/tpu_connector.py +2 -3
  7. tpu_inference/distributed/utils.py +3 -2
  8. tpu_inference/envs.py +1 -1
  9. tpu_inference/executors/ray_distributed_executor.py +27 -11
  10. tpu_inference/kernels/fused_moe/v1/kernel.py +641 -110
  11. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
  12. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +141 -107
  13. tpu_inference/layers/common/attention_interface.py +7 -1
  14. tpu_inference/layers/common/sharding.py +2 -1
  15. tpu_inference/layers/vllm/fused_moe.py +74 -25
  16. tpu_inference/layers/vllm/quantization/common.py +6 -1
  17. tpu_inference/layers/vllm/quantization/mxfp4.py +135 -61
  18. tpu_inference/layers/vllm/quantization/unquantized.py +107 -113
  19. tpu_inference/layers/vllm/sharding.py +2 -2
  20. tpu_inference/lora/torch_punica_tpu.py +1 -2
  21. tpu_inference/models/common/model_loader.py +43 -11
  22. tpu_inference/models/jax/llama3.py +2 -1
  23. tpu_inference/models/jax/llama_eagle3.py +8 -5
  24. tpu_inference/models/jax/llama_guard_4.py +361 -0
  25. tpu_inference/models/jax/qwen2.py +2 -1
  26. tpu_inference/models/jax/qwen2_5_vl.py +163 -48
  27. tpu_inference/models/jax/qwen3.py +2 -1
  28. tpu_inference/models/jax/utils/weight_utils.py +198 -143
  29. tpu_inference/models/vllm/vllm_model_wrapper.py +13 -5
  30. tpu_inference/platforms/tpu_platform.py +15 -2
  31. tpu_inference/runner/compilation_manager.py +58 -33
  32. tpu_inference/runner/kv_cache_manager.py +9 -3
  33. tpu_inference/runner/structured_decoding_manager.py +2 -3
  34. tpu_inference/runner/tpu_runner.py +203 -102
  35. tpu_inference/spec_decode/jax/eagle3.py +19 -2
  36. tpu_inference/tpu_info.py +4 -3
  37. tpu_inference/utils.py +5 -4
  38. tpu_inference/worker/tpu_worker.py +160 -23
  39. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/METADATA +3 -2
  40. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/RECORD +43 -48
  41. tpu_inference/mock/__init__.py +0 -0
  42. tpu_inference/mock/vllm_config_utils.py +0 -28
  43. tpu_inference/mock/vllm_envs.py +0 -1219
  44. tpu_inference/mock/vllm_logger.py +0 -212
  45. tpu_inference/mock/vllm_logging_utils.py +0 -15
  46. tpu_inference/models/jax/phi3.py +0 -376
  47. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/WHEEL +0 -0
  48. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/licenses/LICENSE +0 -0
  49. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/top_level.txt +0 -0
tpu_inference/runner/tpu_runner.py CHANGED
@@ -153,6 +153,7 @@ class ExecuteModelState:
     spec_decode_metadata: Optional[SpecDecodeMetadata]
     kv_connector_output: Optional[KVConnectorOutput]
     logits_indices_selector: Optional[List[int]] = None
+    padded_num_reqs: Optional[int] = None
 
 
 @functools.partial(jax.jit, donate_argnums=(0, 1, 2))
@@ -190,18 +191,28 @@ def _substitute_placeholder_token(
     return input_ids.at[token_in_tpu_cur_input_indices].set(update_values)
 
 
-def _reorder_logits_indices(logprobs_lists, logits_indices_selector):
+def _jax_logprobs_to_lists(logprobs_tensors,
+                           logits_indices_selector=None,
+                           cu_num_generated_tokens=None):
+    """Convert JAX LogprobsTensors to LogprobsLists by converting JAX arrays to numpy."""
+    log_token_ids_list = logprobs_tensors.logprob_token_ids.tolist()
+    logprobs_list = logprobs_tensors.logprobs.tolist()
+    selected_token_ranks_list = logprobs_tensors.selected_token_ranks.tolist()
+
+    if logits_indices_selector is not None:
+        log_token_ids_list = [
+            log_token_ids_list[i] for i in logits_indices_selector
+        ]
+        logprobs_list = [logprobs_list[i] for i in logits_indices_selector]
+        selected_token_ranks_list = [
+            selected_token_ranks_list[i] for i in logits_indices_selector
+        ]
+
     return LogprobsLists(
-        logprob_token_ids=[
-            logprobs_lists.logprob_token_ids[i]
-            for i in logits_indices_selector
-        ],
-        logprobs=[logprobs_lists.logprobs[i] for i in logits_indices_selector],
-        sampled_token_ranks=[
-            logprobs_lists.sampled_token_ranks[i]
-            for i in logits_indices_selector
-        ],
-        cu_num_generated_tokens=logprobs_lists.cu_num_generated_tokens,
+        logprob_token_ids=np.asarray(log_token_ids_list),
+        logprobs=np.asarray(logprobs_list),
+        sampled_token_ranks=np.asarray(selected_token_ranks_list),
+        cu_num_generated_tokens=cu_num_generated_tokens,
    )
 
 
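The new `_jax_logprobs_to_lists` replaces `_reorder_logits_indices`: instead of reordering an already-built `LogprobsLists`, it pulls the JAX `LogprobsTensors` fields to host lists, optionally restores the pre-data-parallel request order, and packs numpy arrays. A minimal sketch of that reorder-then-pack step with toy data (the real container types come from vLLM):

```python
import numpy as np

# Toy stand-in: one row of top-k logprob token ids per request, in DP-shuffled order.
logprob_token_ids = [[11, 7], [42, 3], [5, 9]]
logits_indices_selector = [2, 0, 1]  # output row i comes from shuffled row selector[i]

if logits_indices_selector is not None:
    logprob_token_ids = [logprob_token_ids[i] for i in logits_indices_selector]

packed = np.asarray(logprob_token_ids)  # fields are packed as numpy arrays, not Python lists
print(packed)
```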
@@ -211,6 +222,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         self,
         vllm_config: VllmConfig,
         devices: List[Any],
+        rank: int = 0,
+        is_first_rank: bool = True,
+        is_last_rank: bool = True,
     ):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
@@ -423,8 +437,14 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
         self.input_ids_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
         self.positions_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
-        self.block_table_cpu = np.zeros(
-            (self.max_num_reqs, self.max_num_blocks_per_req), dtype=np.int32)
+        # Note: self.input_batch and self.block_tables_cpu are both initialized
+        # with only 1 block_size. For hybrid kv cache, it will be re-init
+        # in kv_cache_manager's maybe_reinitialize_input_batch.
+        self.block_tables_cpu = [
+            np.zeros((self.max_num_reqs, self.max_num_blocks_per_req),
+                     dtype=np.int32)
+        ]
+
         self.query_start_loc_cpu = np.zeros(self.max_num_reqs + self.dp_size,
                                             dtype=np.int32)
         self.seq_lens_cpu = np.zeros(self.max_num_reqs, dtype=np.int32)
@@ -458,9 +478,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
         # tensors for structured decoding
         self.vocab_size = self.model_config.get_vocab_size()
-        if self.lora_config is not None:
-            # lora_config.lora_extra_vocab_size is the "Maximum size of extra vocabulary that can be present in a LoRA adapter" per https://github.com/vanbasten23/vllm/blob/7f4a8b6705622fde952a2e633e86716f902d6e1b/vllm/config.py#L3040
-            self.vocab_size += self.lora_config.lora_extra_vocab_size
         self.grammar_bitmask_cpu = np.zeros(
             (self.max_num_reqs, cdiv(self.vocab_size, 32)),
             dtype=np.int32,
@@ -505,9 +522,14 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
         self.rng_params_for_sampling = nnx.Rngs(
             jax.random.key(self.model_config.seed)).params()
-        self.is_multimodal_model = (self.model_config.is_multimodal_model
-                                    and self.get_multimodal_embeddings_fn
-                                    is not None)
+        self.is_multimodal_model = (
+            self.model_config.is_multimodal_model
+            and self.get_multimodal_embeddings_fn is not None and hasattr(
+                self.model_config.hf_config, "architectures"
+            )  #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
+            and len(self.model_config.hf_config.architectures) >= 1
+            and self.model_config.hf_config.architectures[0]
+            != "Llama4ForConditionalGeneration")
 
         logger.info(f"Init model | "
                     f"hbm={common_utils.hbm_usage_gb(self.devices)}GiB")
@@ -520,6 +542,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         self.kv_cache_config = kv_cache_config
+        self.use_hybrid_kvcache = len(kv_cache_config.kv_cache_groups) > 1
         self.kv_caches = []
         self.kv_cache_manager.initialize_kv_cache(kv_cache_config)
         if has_kv_transfer_group():
@@ -550,16 +573,17 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
         (scheduler_output, attn_metadata, input_ids, hidden_states, logits,
          aux_hidden_states, spec_decode_metadata, kv_connector_output,
-         logits_indices_selector) = (
-             self.execute_model_state.scheduler_output,
-             self.execute_model_state.attn_metadata,
-             self.execute_model_state.input_ids,
-             self.execute_model_state.hidden_states,
-             self.execute_model_state.logits,
-             self.execute_model_state.aux_hidden_states,
-             self.execute_model_state.spec_decode_metadata,
-             self.execute_model_state.kv_connector_output,
-             self.execute_model_state.logits_indices_selector)
+         logits_indices_selector,
+         padded_num_reqs) = (self.execute_model_state.scheduler_output,
+                             self.execute_model_state.attn_metadata,
+                             self.execute_model_state.input_ids,
+                             self.execute_model_state.hidden_states,
+                             self.execute_model_state.logits,
+                             self.execute_model_state.aux_hidden_states,
+                             self.execute_model_state.spec_decode_metadata,
+                             self.execute_model_state.kv_connector_output,
+                             self.execute_model_state.logits_indices_selector,
+                             self.execute_model_state.padded_num_reqs)
         self.execute_model_state = None
 
         if grammar_output is not None:
@@ -573,12 +597,10 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 logits,
                 arange,
             )
-        return self._sample_from_logits(scheduler_output, attn_metadata,
-                                        input_ids, hidden_states, logits,
-                                        aux_hidden_states,
-                                        spec_decode_metadata,
-                                        kv_connector_output,
-                                        logits_indices_selector)
+        return self._sample_from_logits(
+            scheduler_output, attn_metadata, input_ids, hidden_states, logits,
+            aux_hidden_states, spec_decode_metadata, kv_connector_output,
+            logits_indices_selector, padded_num_reqs)
 
     def _modify_prev_results(self):
         # If copy to host has not been done, we just wait.
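The runner now computes `padded_num_reqs` once in `_prepare_inputs` and threads it through `ExecuteModelState` into `_sample_from_logits`, so both paths pad the request count the same way. The helper `runner_utils.get_padded_num_reqs_with_upper_limit(num_reqs, max_num_reqs)` is only called here as a fallback; its body below is a hypothetical sketch of the usual bucketed-padding idea (pad to a small set of shapes so XLA does not recompile per batch size), not the package's actual implementation:

```python
def get_padded_num_reqs_with_upper_limit(num_reqs: int, upper_limit: int) -> int:
    # Hypothetical: round up to the next power of two (min 8), capped at the limit,
    # so only a handful of request-count shapes ever get traced/compiled.
    padded = 8
    while padded < num_reqs:
        padded *= 2
    return min(padded, upper_limit)

print(get_padded_num_reqs_with_upper_limit(3, 256))    # 8
print(get_padded_num_reqs_with_upper_limit(37, 256))   # 64
print(get_padded_num_reqs_with_upper_limit(500, 256))  # 256
```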
@@ -687,13 +709,23 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         # TODO(pooyam): I guess we can remove returning sampling_metadata in `_prepare_inputs` after https://github.com/njhill/vllm/commit/b7433ca1a47732394b1bdea4099d98389515954b
         (
             input_ids,
+            input_positions,
             attn_metadata,
             _,
             logits_indices,
             spec_decode_metadata,
             logits_indices_selector,
+            padded_num_reqs,
         ) = self._prepare_inputs(scheduler_output)
 
+        is_llama_guard_4 = (
+            hasattr(
+                self.model_config.hf_config, "architectures"
+            )  #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
+            and len(self.model_config.hf_config.architectures) >= 1
+            and self.model_config.hf_config.architectures[0]
+            == "Llama4ForConditionalGeneration")
+
         # multi-modal support
         if self.is_multimodal_model:
             # Run the multimodal encoder if any.
@@ -701,6 +733,13 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 self.mm_manager.execute_mm_encoder(scheduler_output)
                 mm_embeds = self.mm_manager.gather_mm_embeddings(
                     scheduler_output, input_ids.shape[0])
+        #TODO: Remove the follow elif statement once Llama Guard 4 Vision portion has been implemented
+        elif is_llama_guard_4 and any(
+                self.mm_manager.runner.requests[req_id].mm_features
+                for req_id in self.mm_manager.runner.input_batch.req_ids):
+            raise NotImplementedError(
+                "Llama Guard 4 (JAX) currently supports only text inputs. "
+                "Multimodal processing not yet implemented.")
         else:
             mm_embeds = []
 
@@ -733,6 +772,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 input_ids,
                 attn_metadata,
                 inputs_embeds,
+                input_positions,
                 tuple(self.layer_name_to_kvcache_index.items()),
                 lora_metadata,
             )
@@ -754,7 +794,8 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             aux_hidden_states=aux_hidden_states,
             spec_decode_metadata=spec_decode_metadata,
             kv_connector_output=kv_connector_output,
-            logits_indices_selector=logits_indices_selector)
+            logits_indices_selector=logits_indices_selector,
+            padded_num_reqs=padded_num_reqs)
         return attn_metadata, None
 
     def _sample_from_logits(
@@ -768,23 +809,44 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         spec_decode_metadata: Optional[SpecDecodeMetadata],
         kv_connector_output: Optional[KVConnectorOutput],
         logits_indices_selector: Optional[List[int]] = None,
+        padded_num_reqs: Optional[int] = None,
     ) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput:
-        padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
-            self.input_batch.num_reqs, self.max_num_reqs)
+        if padded_num_reqs is None:
+            padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
+                self.input_batch.num_reqs, self.max_num_reqs)
+
+        sharding = None
+        if self.dp_size > 1:
+            sharding = NamedSharding(self.mesh,
+                                     PartitionSpec(ShardingAxisName.ATTN_DATA))
+
         tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
-            self.mesh, self.input_batch, padded_num_reqs)
+            self.mesh, self.input_batch, padded_num_reqs, sharding=sharding)
+
+        # TODO(pooyam): Should we move this to `_prepare_inputs`?
+        if tpu_sampling_metadata.do_sampling:
+            self.rng_params_for_sampling, step_rng = jax.random.split(
+                self.rng_params_for_sampling)
+        else:
+            step_rng = self.rng_params_for_sampling
+
         if spec_decode_metadata is None:
             next_tokens = sample(
-                self.rng_params_for_sampling,
+                step_rng,
                 self.mesh,
                 logits,
                 tpu_sampling_metadata,
             )
         else:
+            if tpu_sampling_metadata.do_sampling:
+                bonus_rng, rejection_rng = jax.random.split(step_rng)
+            else:
+                bonus_rng = step_rng
+                rejection_rng = step_rng
             bonus_logits = self._select_from_array_fn(
                 logits, spec_decode_metadata.bonus_logits_indices)
             bonus_token_ids = sample(
-                self.rng_params_for_sampling,
+                bonus_rng,
                 self.mesh,
                 bonus_logits,
                 tpu_sampling_metadata,
@@ -798,7 +860,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 target_logits=target_logits,
                 bonus_token_ids=bonus_token_ids,
                 sampling_metadata=tpu_sampling_metadata,
-                key=self.rng_params_for_sampling,
+                key=rejection_rng,
             )
 
         if tpu_sampling_metadata.logprobs:
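The sampling path now splits the stored PRNG key on every step instead of reusing `self.rng_params_for_sampling` directly, so repeated sampling calls, and the bonus/rejection paths in speculative decoding, draw from independent random streams. A minimal sketch of that pattern, with placeholder names:

```python
import jax

key = jax.random.key(0)   # persistent sampler state, analogous to rng_params_for_sampling

# Per step: carry a fresh key forward and use a one-off key for this step's sampling.
key, step_key = jax.random.split(key)

# With speculative decoding, the step key is split again so the bonus-token
# sample and the rejection sampler never share randomness.
bonus_key, rejection_key = jax.random.split(step_key)
print(jax.random.uniform(bonus_key), jax.random.uniform(rejection_key))
```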
@@ -856,10 +918,8 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
         if logprobs is not None:
             # Map logprobs back to the pre-dp shuffling order
-            logprobs_lists = logprobs.tolists()
-            if logits_indices_selector is not None:
-                logprobs_lists = _reorder_logits_indices(
-                    logprobs_lists, logits_indices_selector)
+            logprobs_lists = _jax_logprobs_to_lists(
+                logprobs, logits_indices_selector)
 
         else:
             logprobs_lists = None
@@ -929,10 +989,8 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
         if logprobs is not None:
             # Map logprobs back to the pre-dp shuffling order
-            logprobs_lists = logprobs.tolists()
-            if logits_indices_selector is not None:
-                logprobs_lists = _reorder_logits_indices(
-                    logprobs_lists, logits_indices_selector)
+            logprobs_lists = _jax_logprobs_to_lists(logprobs,
+                                                    logits_indices_selector)
         else:
             logprobs_lists = None
 
@@ -1280,16 +1338,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         mrope_positions = self.mrope_positions_cpu[:, :
                                                    padded_total_num_scheduled_tokens]
 
-        block_tables = self.block_table_cpu[:self.max_num_reqs]
-        for dp_rank in range(dp_size):
-            req_offset = dp_rank * max_num_reqs_per_dp_rank
-            _num_reqs = num_req_per_dp_rank[dp_rank]
-
-            block_tables[
-                req_offset:req_offset + _num_reqs, :self.
-                max_num_blocks_per_req] = self.input_batch.block_table[
-                    0].get_cpu_tensor()[req_indices_dp[dp_rank]]
-
         query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs +
                                                    dp_size]
         seq_lens = self.seq_lens_cpu[:self.max_num_reqs]
@@ -1331,20 +1379,59 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         if self.uses_mrope:
             positions = mrope_positions
 
-        # Convert block_tables to 1D on cpu.
-        block_tables = block_tables.reshape(-1)
-
         query_start_loc_cpu = query_start_loc
         logits_indices_cpu = logits_indices
         seq_lens_cpu = seq_lens
 
-        (input_ids, positions, block_tables, query_start_loc, seq_lens,
-         logits_indices, request_distribution) = device_array(
+        (input_ids, positions, query_start_loc, seq_lens, logits_indices,
+         request_distribution) = device_array(
             self.mesh,
-            (input_ids, positions, block_tables, query_start_loc, seq_lens,
-             logits_indices, request_distribution),
+            (input_ids, positions, query_start_loc, seq_lens, logits_indices,
+             request_distribution),
             sharding=data_parallel_attn_sharding,
         )
+
+        attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+        uniform_attention_metadata: AttentionMetadata = None
+        for kv_cache_gid, kv_cache_group in enumerate(
+                self.kv_cache_config.kv_cache_groups):
+            block_tables = self.block_tables_cpu[kv_cache_gid][:self.
+                                                               max_num_reqs]
+            for dp_rank in range(dp_size):
+                req_offset = dp_rank * max_num_reqs_per_dp_rank
+                _num_reqs = num_req_per_dp_rank[dp_rank]
+
+                block_tables[
+                    req_offset:req_offset + _num_reqs, :self.
+                    max_num_blocks_per_req] = self.input_batch.block_table[
+                        0].get_cpu_tensor()[req_indices_dp[dp_rank]]
+            # Convert block_tables to 1D on cpu.
+            block_tables = block_tables.reshape(-1)
+            block_tables = device_array(
+                self.mesh,
+                (block_tables),
+                sharding=data_parallel_attn_sharding,
+            )
+
+            attention_metadata_gid = AttentionMetadata(
+                input_positions=positions,
+                block_tables=block_tables,
+                seq_lens=seq_lens,
+                query_start_loc=query_start_loc,
+                request_distribution=request_distribution,
+            )
+
+            # This is for making these cpu buffers hidden during tracing
+            attention_metadata_gid.query_start_loc_cpu = query_start_loc_cpu
+            attention_metadata_gid.seq_lens_cpu = seq_lens_cpu
+
+            if not self.use_hybrid_kvcache:
+                uniform_attention_metadata = attention_metadata_gid
+            else:
+                for layer_name in kv_cache_group.layer_names:
+                    attention_metadata_per_layer[
+                        layer_name] = attention_metadata_gid
+
         # Async scheduling: substitute placeholder tokens for DP
         if self.scheduler_config.async_scheduling and self._pre_async_results is not None:
             # Collect all token indices that need substitution across all DP ranks
@@ -1373,25 +1460,19 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 padded_total_num_scheduled_tokens,
             )
 
-        attention_metadata = AttentionMetadata(
-            input_positions=positions,
-            block_tables=block_tables,
-            seq_lens=seq_lens,
-            query_start_loc=query_start_loc,
-            request_distribution=request_distribution,
-        )
-
-        # This is for making these cpu buffers hidden during tracing
-        attention_metadata.query_start_loc_cpu = query_start_loc_cpu
-        attention_metadata.seq_lens_cpu = seq_lens_cpu
-
+        if self.use_hybrid_kvcache:
+            attention_metadata = attention_metadata_per_layer
+        else:
+            attention_metadata = uniform_attention_metadata
         return (
             input_ids,
+            positions,
             attention_metadata,
             sampling_metadata,
             logits_indices,
             spec_decode_metadata,
             logits_indices_selector,
+            padded_num_reqs,
         )
 
     def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"):
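With hybrid KV caches, `_prepare_inputs` now builds one `AttentionMetadata` per KV cache group; when more than one group exists it returns a dict keyed by layer name, otherwise it returns the single shared metadata object as before. A simplified sketch of that dispatch, using made-up group and layer names in place of the real `kv_cache_config`:

```python
from dataclasses import dataclass
from typing import Dict, List

@dataclass
class GroupMeta:              # stand-in for AttentionMetadata
    block_tables: List[int]

# Hypothetical KV cache groups, e.g. full-attention layers vs. sliding-window layers.
kv_cache_groups = [
    {"layer_names": ["layers.0.attn", "layers.2.attn"], "block_tables": [0, 1]},
    {"layer_names": ["layers.1.attn"], "block_tables": [2]},
]
use_hybrid_kvcache = len(kv_cache_groups) > 1

per_layer: Dict[str, GroupMeta] = {}
uniform = None
for group in kv_cache_groups:
    meta = GroupMeta(block_tables=group["block_tables"])
    if not use_hybrid_kvcache:
        uniform = meta                    # every layer shares one metadata object
    else:
        for name in group["layer_names"]:
            per_layer[name] = meta        # hybrid: each layer looks up its group's metadata

attention_metadata = per_layer if use_hybrid_kvcache else uniform
print(attention_metadata)
```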
@@ -1492,9 +1573,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         positions = self.positions_cpu[:padded_total_num_scheduled_tokens]
         mrope_positions = self.mrope_positions_cpu[:, :
                                                    padded_total_num_scheduled_tokens]
-        block_tables = self.block_table_cpu[:self.max_num_reqs]
-        block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
-            self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs])
 
         # TODO(pooyam): Some paddings are up to `num_reqs_paddings` (spec decoding, select hidden states, etc) and some other are to `max_num_reqs` (block table, seq_lens). We should stick to one of them maybe?
         query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1]
@@ -1523,16 +1601,44 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             self.mesh, self.input_batch, padded_num_reqs)
         if self.uses_mrope:
             positions = mrope_positions
-
-        # Convert block_tables to 1D on cpu.
-        block_tables = block_tables.reshape(-1)
-
         query_start_loc_cpu = query_start_loc
         seq_lens_cpu = seq_lens
-        (input_ids, positions, block_tables, query_start_loc, seq_lens,
+
+        (input_ids, positions, query_start_loc, seq_lens,
          logits_indices, request_distribution) = device_array(
-            self.mesh, (input_ids, positions, block_tables, query_start_loc,
-                        seq_lens, logits_indices, request_distribution))
+            self.mesh, (input_ids, positions, query_start_loc, seq_lens,
+                        logits_indices, request_distribution))
+
+        attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+        uniform_attention_metadata: AttentionMetadata = None
+        for kv_cache_gid, kv_cache_group in enumerate(
+                self.kv_cache_config.kv_cache_groups):
+            block_tables = self.block_tables_cpu[kv_cache_gid][:self.
+                                                               max_num_reqs]
+            block_tables[:num_reqs] = (
+                self.input_batch.block_table[kv_cache_gid].get_cpu_tensor()
+                [:num_reqs])
+            # Convert block_tables to 1D on cpu.
+            block_tables = block_tables.reshape(-1)
+            block_tables = device_array(self.mesh, (block_tables))
+
+            attention_metadata_gid = AttentionMetadata(
+                input_positions=positions,
+                block_tables=block_tables,
+                seq_lens=seq_lens,
+                query_start_loc=query_start_loc,
+                request_distribution=request_distribution)
+            # This is for making these cpu buffers hidden during tracing
+            attention_metadata_gid.query_start_loc_cpu = query_start_loc_cpu
+            attention_metadata_gid.seq_lens_cpu = seq_lens_cpu
+
+            if not self.use_hybrid_kvcache:
+                # all layers share the same attention metadata
+                uniform_attention_metadata = attention_metadata_gid
+            else:
+                for layer_name in kv_cache_group.layer_names:
+                    attention_metadata_per_layer[
+                        layer_name] = attention_metadata_gid
 
         if self.scheduler_config.async_scheduling and len(
                 token_in_tpu_cur_input_indices) > 0:
@@ -1545,20 +1651,15 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             self.lora_utils.set_active_loras(
                 num_scheduled_tokens_per_req, total_num_scheduled_tokens,
                 padded_total_num_scheduled_tokens)
-
-        attention_metadata = AttentionMetadata(
-            input_positions=positions,
-            block_tables=block_tables,
-            seq_lens=seq_lens,
-            query_start_loc=query_start_loc,
-            request_distribution=request_distribution)
-
-        # This is for making these cpu buffers hidden during tracing
-        attention_metadata.query_start_loc_cpu = query_start_loc_cpu
-        attention_metadata.seq_lens_cpu = seq_lens_cpu
         logits_indices_selector = None
-        return (input_ids, attention_metadata, sampling_metadata,
-                logits_indices, spec_decode_metadata, logits_indices_selector)
+
+        if self.use_hybrid_kvcache:
+            attention_metadata = attention_metadata_per_layer
+        else:
+            attention_metadata = uniform_attention_metadata
+        return (input_ids, positions, attention_metadata, sampling_metadata,
+                logits_indices, spec_decode_metadata, logits_indices_selector,
+                padded_num_reqs)
 
     def _get_input_ids_embeds(self, input_ids: jax.Array,
                               mm_embeds: list[jax.Array]):
tpu_inference/spec_decode/jax/eagle3.py CHANGED
@@ -9,10 +9,13 @@ import numpy as np
 from vllm.config import VllmConfig
 
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.logger import init_logger
 from tpu_inference.models.common.model_loader import get_model
 from tpu_inference.runner import utils as runner_utils
 from tpu_inference.utils import device_array
 
+logger = init_logger(__name__)
+
 
 class Eagle3Proposer:
     """A proposer for speculative decoding using the Eagle3 method.
@@ -51,8 +54,22 @@ class Eagle3Proposer:
         """Loads the draft model."""
         self.model_fn, self.compute_logits_fn, self.combine_hidden_states_fn, _, self.state, _, _ = get_model(
             self.vllm_config, self.rng_key, self.mesh, is_draft_model=True)
-        del self.state.model['embed_tokens']
-        self.state.model.embed_tokens = target_model.model.embed
+
+        draft_embed_tokens = getattr(self.state.model, 'embed_tokens', None)
+        if draft_embed_tokens is None or ~jnp.any(
+                draft_embed_tokens.embedding):
+            logger.info(
+                "Draft model does not have embedding. Setting draft model's embed_tokens to target model's embed"
+            )
+            self.state.model.embed_tokens = target_model.model.embed
+        elif jnp.array_equal(draft_embed_tokens.embedding,
+                             target_model.model.embed.embedding):
+            logger.info(
+                "Draft model's embed_tokens is identical to target model's embed. Sharing the embedding."
+            )
+            self.state.model.embed_tokens = target_model.model.embed
+        else:
+            logger.info("Draft model has its own embed_tokens.")
 
     @functools.partial(jax.jit, static_argnums=(0, ))
     def _prepare_input_ids(
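The new load path only shares the target model's embedding table when the draft checkpoint ships no usable `embed_tokens` (missing or all zeros) or when the two tables are identical; otherwise the draft keeps its own weights. A small, self-contained sketch of those checks with a hypothetical helper (note that `~jnp.any(x)` is the "all entries are zero" test used in the hunk above):

```python
import jax.numpy as jnp

def should_share_embedding(draft_table, target_table):
    """Return True when the draft model should reuse the target embedding table."""
    if draft_table is None or ~jnp.any(draft_table):           # missing or all-zero weights
        return True
    return bool(jnp.array_equal(draft_table, target_table))    # identical tables

target = jnp.arange(6.0).reshape(2, 3)
print(should_share_embedding(None, target))                # True  (no draft table at all)
print(should_share_embedding(jnp.zeros((2, 3)), target))   # True  (all-zero placeholder)
print(should_share_embedding(target + 1.0, target))        # False (draft has its own weights)
```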
tpu_inference/tpu_info.py CHANGED
@@ -3,6 +3,7 @@ import os
 
 import requests
 
+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 
 logger = init_logger(__name__)
@@ -32,14 +33,14 @@ def get_tpu_metadata(key: str = "") -> str:
 
 
 def get_tpu_type() -> str:
-    tpu_type = os.getenv("TPU_ACCELERATOR_TYPE", None)
+    tpu_type = envs.TPU_ACCELERATOR_TYPE
     if tpu_type is None:
         tpu_type = get_tpu_metadata(key="accelerator-type")
     return tpu_type
 
 
 def get_node_name() -> str:
-    tpu_name = os.getenv("TPU_NAME", None)
+    tpu_name = envs.TPU_NAME
     if not tpu_name:
         tpu_name = get_tpu_metadata(key="instance-id")
     return tpu_name
@@ -47,7 +48,7 @@ def get_node_name() -> str:
 
 def get_node_worker_id() -> int:
     """For multi-host TPU VM, this returns the worker id for the current node."""
-    worker_id = os.getenv("TPU_WORKER_ID", None)
+    worker_id = envs.TPU_WORKER_ID
     if worker_id is None:
         worker_id = get_tpu_metadata(key="agent-worker-number")
     if worker_id is None:
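These changes route environment lookups through the package's `tpu_inference.envs` module instead of ad-hoc `os.getenv` calls, so the recognized variables and their defaults live in one place. A minimal sketch of the general pattern such an envs module follows (the variable list, defaults, and use of a module-level `__getattr__` here are illustrative assumptions, not the actual contents of `tpu_inference/envs.py`):

```python
import os

# Hypothetical centralized registry: one place defines names, defaults, and parsing.
_ENV_VARS = {
    "TPU_ACCELERATOR_TYPE": lambda: os.getenv("TPU_ACCELERATOR_TYPE"),
    "TPU_NAME": lambda: os.getenv("TPU_NAME"),
    "TPU_WORKER_ID": lambda: os.getenv("TPU_WORKER_ID"),
    "TPU_MULTIHOST_BACKEND": lambda: os.getenv("TPU_MULTIHOST_BACKEND", "").lower(),
}

def __getattr__(name):
    # Module-level __getattr__ (PEP 562) lets callers write `envs.TPU_NAME`
    # and always read the current process environment.
    if name in _ENV_VARS:
        return _ENV_VARS[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```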
tpu_inference/utils.py CHANGED
@@ -1,5 +1,4 @@
 # SPDX-License-Identifier: Apache-2.0
-import os
 import time
 from collections import defaultdict
 from collections.abc import Sequence
@@ -14,8 +13,10 @@ from jax._src import mesh as mesh_lib
 from jax._src import xla_bridge as xb
 from jax._src.lib import xla_client as xc
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
-from vllm import envs, utils
+from vllm import envs as vllm_envs
+from vllm import utils
 
+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 
 GBYTES = 1024 * 1024 * 1024
@@ -57,10 +58,10 @@ def get_num_kv_heads_by_tp(num_kv_heads: int, tp_size: int) -> int:
 
 def hbm_usage_bytes(devices: Any) -> List[Tuple[int, int]]:
     usage = []
-    if envs.VLLM_TPU_USING_PATHWAYS:
+    if vllm_envs.VLLM_TPU_USING_PATHWAYS:
         return pathways_hbm_usage_gb(devices)
 
-    multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+    multihost_backend = envs.TPU_MULTIHOST_BACKEND
     if multihost_backend == "ray":
         # MemoryStats is only supported for addressable PjRt devices.
         # Assume all the devices have similar memory usage for now.