tpu-inference 0.11.1.dev202511150811__py3-none-any.whl → 0.11.1.dev202511270815__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/lora/test_layers.py +0 -6
- tests/lora/utils.py +0 -8
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +2 -3
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +1 -1
- tpu_inference/executors/ray_distributed_executor.py +27 -11
- tpu_inference/kernels/fused_moe/v1/kernel.py +641 -110
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +141 -107
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +2 -1
- tpu_inference/layers/vllm/fused_moe.py +74 -25
- tpu_inference/layers/vllm/quantization/common.py +6 -1
- tpu_inference/layers/vllm/quantization/mxfp4.py +135 -61
- tpu_inference/layers/vllm/quantization/unquantized.py +107 -113
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +43 -11
- tpu_inference/models/jax/llama3.py +2 -1
- tpu_inference/models/jax/llama_eagle3.py +8 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +2 -1
- tpu_inference/models/jax/qwen2_5_vl.py +163 -48
- tpu_inference/models/jax/qwen3.py +2 -1
- tpu_inference/models/jax/utils/weight_utils.py +198 -143
- tpu_inference/models/vllm/vllm_model_wrapper.py +13 -5
- tpu_inference/platforms/tpu_platform.py +15 -2
- tpu_inference/runner/compilation_manager.py +58 -33
- tpu_inference/runner/kv_cache_manager.py +9 -3
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +203 -102
- tpu_inference/spec_decode/jax/eagle3.py +19 -2
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +5 -4
- tpu_inference/worker/tpu_worker.py +160 -23
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/METADATA +3 -2
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/RECORD +43 -48
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/top_level.txt +0 -0
tpu_inference/runner/tpu_runner.py CHANGED

@@ -153,6 +153,7 @@ class ExecuteModelState:
     spec_decode_metadata: Optional[SpecDecodeMetadata]
     kv_connector_output: Optional[KVConnectorOutput]
     logits_indices_selector: Optional[List[int]] = None
+    padded_num_reqs: Optional[int] = None


 @functools.partial(jax.jit, donate_argnums=(0, 1, 2))
@@ -190,18 +191,28 @@ def _substitute_placeholder_token(
     return input_ids.at[token_in_tpu_cur_input_indices].set(update_values)


-def
+def _jax_logprobs_to_lists(logprobs_tensors,
+                           logits_indices_selector=None,
+                           cu_num_generated_tokens=None):
+    """Convert JAX LogprobsTensors to LogprobsLists by converting JAX arrays to numpy."""
+    log_token_ids_list = logprobs_tensors.logprob_token_ids.tolist()
+    logprobs_list = logprobs_tensors.logprobs.tolist()
+    selected_token_ranks_list = logprobs_tensors.selected_token_ranks.tolist()
+
+    if logits_indices_selector is not None:
+        log_token_ids_list = [
+            log_token_ids_list[i] for i in logits_indices_selector
+        ]
+        logprobs_list = [logprobs_list[i] for i in logits_indices_selector]
+        selected_token_ranks_list = [
+            selected_token_ranks_list[i] for i in logits_indices_selector
+        ]
+
     return LogprobsLists(
-        logprob_token_ids=
-
-
-
-        logprobs=[logprobs_lists.logprobs[i] for i in logits_indices_selector],
-        sampled_token_ranks=[
-            logprobs_lists.sampled_token_ranks[i]
-            for i in logits_indices_selector
-        ],
-        cu_num_generated_tokens=logprobs_lists.cu_num_generated_tokens,
+        logprob_token_ids=np.asarray(log_token_ids_list),
+        logprobs=np.asarray(logprobs_list),
+        sampled_token_ranks=np.asarray(selected_token_ranks_list),
+        cu_num_generated_tokens=cu_num_generated_tokens,
     )


@@ -211,6 +222,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         self,
         vllm_config: VllmConfig,
         devices: List[Any],
+        rank: int = 0,
+        is_first_rank: bool = True,
+        is_last_rank: bool = True,
     ):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
@@ -423,8 +437,14 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

         self.input_ids_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
         self.positions_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
-        self.
-
+        # Note: self.input_batch and self.block_tables_cpu are both initialized
+        # with only 1 block_size. For hybrid kv cache, it will be re-init
+        # in kv_cache_manager's maybe_reinitialize_input_batch.
+        self.block_tables_cpu = [
+            np.zeros((self.max_num_reqs, self.max_num_blocks_per_req),
+                     dtype=np.int32)
+        ]
+
         self.query_start_loc_cpu = np.zeros(self.max_num_reqs + self.dp_size,
                                             dtype=np.int32)
         self.seq_lens_cpu = np.zeros(self.max_num_reqs, dtype=np.int32)
@@ -458,9 +478,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

         # tensors for structured decoding
         self.vocab_size = self.model_config.get_vocab_size()
-        if self.lora_config is not None:
-            # lora_config.lora_extra_vocab_size is the "Maximum size of extra vocabulary that can be present in a LoRA adapter" per https://github.com/vanbasten23/vllm/blob/7f4a8b6705622fde952a2e633e86716f902d6e1b/vllm/config.py#L3040
-            self.vocab_size += self.lora_config.lora_extra_vocab_size
         self.grammar_bitmask_cpu = np.zeros(
             (self.max_num_reqs, cdiv(self.vocab_size, 32)),
             dtype=np.int32,
@@ -505,9 +522,14 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

         self.rng_params_for_sampling = nnx.Rngs(
             jax.random.key(self.model_config.seed)).params()
-        self.is_multimodal_model = (
-
-
+        self.is_multimodal_model = (
+            self.model_config.is_multimodal_model
+            and self.get_multimodal_embeddings_fn is not None and hasattr(
+                self.model_config.hf_config, "architectures"
+            ) #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
+            and len(self.model_config.hf_config.architectures) >= 1
+            and self.model_config.hf_config.architectures[0]
+            != "Llama4ForConditionalGeneration")

         logger.info(f"Init model | "
                     f"hbm={common_utils.hbm_usage_gb(self.devices)}GiB")
@@ -520,6 +542,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         self.kv_cache_config = kv_cache_config
+        self.use_hybrid_kvcache = len(kv_cache_config.kv_cache_groups) > 1
         self.kv_caches = []
         self.kv_cache_manager.initialize_kv_cache(kv_cache_config)
         if has_kv_transfer_group():
@@ -550,16 +573,17 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

         (scheduler_output, attn_metadata, input_ids, hidden_states, logits,
          aux_hidden_states, spec_decode_metadata, kv_connector_output,
-         logits_indices_selector
-
-
-
-
-
-
-
-
+         logits_indices_selector,
+         padded_num_reqs) = (self.execute_model_state.scheduler_output,
+                             self.execute_model_state.attn_metadata,
+                             self.execute_model_state.input_ids,
+                             self.execute_model_state.hidden_states,
+                             self.execute_model_state.logits,
+                             self.execute_model_state.aux_hidden_states,
+                             self.execute_model_state.spec_decode_metadata,
+                             self.execute_model_state.kv_connector_output,
+                             self.execute_model_state.logits_indices_selector,
+                             self.execute_model_state.padded_num_reqs)
         self.execute_model_state = None

         if grammar_output is not None:
@@ -573,12 +597,10 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 logits,
                 arange,
             )
-            return self._sample_from_logits(
-
-
-
-                kv_connector_output,
-                logits_indices_selector)
+            return self._sample_from_logits(
+                scheduler_output, attn_metadata, input_ids, hidden_states, logits,
+                aux_hidden_states, spec_decode_metadata, kv_connector_output,
+                logits_indices_selector, padded_num_reqs)

     def _modify_prev_results(self):
         # If copy to host has not been done, we just wait.
@@ -687,13 +709,23 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         # TODO(pooyam): I guess we can remove returning sampling_metadata in `_prepare_inputs` after https://github.com/njhill/vllm/commit/b7433ca1a47732394b1bdea4099d98389515954b
         (
             input_ids,
+            input_positions,
             attn_metadata,
             _,
             logits_indices,
             spec_decode_metadata,
             logits_indices_selector,
+            padded_num_reqs,
         ) = self._prepare_inputs(scheduler_output)

+        is_llama_guard_4 = (
+            hasattr(
+                self.model_config.hf_config, "architectures"
+            ) #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
+            and len(self.model_config.hf_config.architectures) >= 1
+            and self.model_config.hf_config.architectures[0]
+            == "Llama4ForConditionalGeneration")
+
         # multi-modal support
         if self.is_multimodal_model:
             # Run the multimodal encoder if any.
@@ -701,6 +733,13 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             self.mm_manager.execute_mm_encoder(scheduler_output)
             mm_embeds = self.mm_manager.gather_mm_embeddings(
                 scheduler_output, input_ids.shape[0])
+        #TODO: Remove the follow elif statement once Llama Guard 4 Vision portion has been implemented
+        elif is_llama_guard_4 and any(
+                self.mm_manager.runner.requests[req_id].mm_features
+                for req_id in self.mm_manager.runner.input_batch.req_ids):
+            raise NotImplementedError(
+                "Llama Guard 4 (JAX) currently supports only text inputs. "
+                "Multimodal processing not yet implemented.")
         else:
             mm_embeds = []

@@ -733,6 +772,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             input_ids,
             attn_metadata,
             inputs_embeds,
+            input_positions,
             tuple(self.layer_name_to_kvcache_index.items()),
             lora_metadata,
         )
@@ -754,7 +794,8 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             aux_hidden_states=aux_hidden_states,
             spec_decode_metadata=spec_decode_metadata,
             kv_connector_output=kv_connector_output,
-            logits_indices_selector=logits_indices_selector
+            logits_indices_selector=logits_indices_selector,
+            padded_num_reqs=padded_num_reqs)
         return attn_metadata, None

     def _sample_from_logits(
@@ -768,23 +809,44 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         spec_decode_metadata: Optional[SpecDecodeMetadata],
         kv_connector_output: Optional[KVConnectorOutput],
         logits_indices_selector: Optional[List[int]] = None,
+        padded_num_reqs: Optional[int] = None,
     ) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput:
-        padded_num_reqs
-
+        if padded_num_reqs is None:
+            padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
+                self.input_batch.num_reqs, self.max_num_reqs)
+
+        sharding = None
+        if self.dp_size > 1:
+            sharding = NamedSharding(self.mesh,
+                                     PartitionSpec(ShardingAxisName.ATTN_DATA))
+
         tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
-            self.mesh, self.input_batch, padded_num_reqs)
+            self.mesh, self.input_batch, padded_num_reqs, sharding=sharding)
+
+        # TODO(pooyam): Should we move this to `_prepare_inputs`?
+        if tpu_sampling_metadata.do_sampling:
+            self.rng_params_for_sampling, step_rng = jax.random.split(
+                self.rng_params_for_sampling)
+        else:
+            step_rng = self.rng_params_for_sampling
+
         if spec_decode_metadata is None:
             next_tokens = sample(
-
+                step_rng,
                 self.mesh,
                 logits,
                 tpu_sampling_metadata,
             )
         else:
+            if tpu_sampling_metadata.do_sampling:
+                bonus_rng, rejection_rng = jax.random.split(step_rng)
+            else:
+                bonus_rng = step_rng
+                rejection_rng = step_rng
             bonus_logits = self._select_from_array_fn(
                 logits, spec_decode_metadata.bonus_logits_indices)
             bonus_token_ids = sample(
-
+                bonus_rng,
                 self.mesh,
                 bonus_logits,
                 tpu_sampling_metadata,
@@ -798,7 +860,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 target_logits=target_logits,
                 bonus_token_ids=bonus_token_ids,
                 sampling_metadata=tpu_sampling_metadata,
-                key=
+                key=rejection_rng,
             )

         if tpu_sampling_metadata.logprobs:
@@ -856,10 +918,8 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

         if logprobs is not None:
             # Map logprobs back to the pre-dp shuffling order
-            logprobs_lists =
-
-            logprobs_lists = _reorder_logits_indices(
-                logprobs_lists, logits_indices_selector)
+            logprobs_lists = _jax_logprobs_to_lists(
+                logprobs, logits_indices_selector)

         else:
             logprobs_lists = None
@@ -929,10 +989,8 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

         if logprobs is not None:
             # Map logprobs back to the pre-dp shuffling order
-            logprobs_lists = logprobs
-
-            logprobs_lists = _reorder_logits_indices(
-                logprobs_lists, logits_indices_selector)
+            logprobs_lists = _jax_logprobs_to_lists(logprobs,
+                                                    logits_indices_selector)
         else:
             logprobs_lists = None

@@ -1280,16 +1338,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         mrope_positions = self.mrope_positions_cpu[:, :
                                                    padded_total_num_scheduled_tokens]

-        block_tables = self.block_table_cpu[:self.max_num_reqs]
-        for dp_rank in range(dp_size):
-            req_offset = dp_rank * max_num_reqs_per_dp_rank
-            _num_reqs = num_req_per_dp_rank[dp_rank]
-
-            block_tables[
-                req_offset:req_offset + _num_reqs, :self.
-                max_num_blocks_per_req] = self.input_batch.block_table[
-                    0].get_cpu_tensor()[req_indices_dp[dp_rank]]
-
         query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs +
                                                    dp_size]
         seq_lens = self.seq_lens_cpu[:self.max_num_reqs]
@@ -1331,20 +1379,59 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         if self.uses_mrope:
             positions = mrope_positions

-        # Convert block_tables to 1D on cpu.
-        block_tables = block_tables.reshape(-1)
-
         query_start_loc_cpu = query_start_loc
         logits_indices_cpu = logits_indices
         seq_lens_cpu = seq_lens

-        (input_ids, positions,
-
+        (input_ids, positions, query_start_loc, seq_lens, logits_indices,
+         request_distribution) = device_array(
             self.mesh,
-            (input_ids, positions,
-
+            (input_ids, positions, query_start_loc, seq_lens, logits_indices,
+             request_distribution),
             sharding=data_parallel_attn_sharding,
         )
+
+        attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+        uniform_attention_metadata: AttentionMetadata = None
+        for kv_cache_gid, kv_cache_group in enumerate(
+                self.kv_cache_config.kv_cache_groups):
+            block_tables = self.block_tables_cpu[kv_cache_gid][:self.
+                                                               max_num_reqs]
+            for dp_rank in range(dp_size):
+                req_offset = dp_rank * max_num_reqs_per_dp_rank
+                _num_reqs = num_req_per_dp_rank[dp_rank]
+
+                block_tables[
+                    req_offset:req_offset + _num_reqs, :self.
+                    max_num_blocks_per_req] = self.input_batch.block_table[
+                        0].get_cpu_tensor()[req_indices_dp[dp_rank]]
+            # Convert block_tables to 1D on cpu.
+            block_tables = block_tables.reshape(-1)
+            block_tables = device_array(
+                self.mesh,
+                (block_tables),
+                sharding=data_parallel_attn_sharding,
+            )
+
+            attention_metadata_gid = AttentionMetadata(
+                input_positions=positions,
+                block_tables=block_tables,
+                seq_lens=seq_lens,
+                query_start_loc=query_start_loc,
+                request_distribution=request_distribution,
+            )
+
+            # This is for making these cpu buffers hidden during tracing
+            attention_metadata_gid.query_start_loc_cpu = query_start_loc_cpu
+            attention_metadata_gid.seq_lens_cpu = seq_lens_cpu
+
+            if not self.use_hybrid_kvcache:
+                uniform_attention_metadata = attention_metadata_gid
+            else:
+                for layer_name in kv_cache_group.layer_names:
+                    attention_metadata_per_layer[
+                        layer_name] = attention_metadata_gid
+
         # Async scheduling: substitute placeholder tokens for DP
         if self.scheduler_config.async_scheduling and self._pre_async_results is not None:
             # Collect all token indices that need substitution across all DP ranks
@@ -1373,25 +1460,19 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 padded_total_num_scheduled_tokens,
             )

-
-
-
-
-            query_start_loc=query_start_loc,
-            request_distribution=request_distribution,
-        )
-
-        # This is for making these cpu buffers hidden during tracing
-        attention_metadata.query_start_loc_cpu = query_start_loc_cpu
-        attention_metadata.seq_lens_cpu = seq_lens_cpu
-
+        if self.use_hybrid_kvcache:
+            attention_metadata = attention_metadata_per_layer
+        else:
+            attention_metadata = uniform_attention_metadata
         return (
             input_ids,
+            positions,
             attention_metadata,
             sampling_metadata,
             logits_indices,
             spec_decode_metadata,
             logits_indices_selector,
+            padded_num_reqs,
         )

     def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"):
@@ -1492,9 +1573,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         positions = self.positions_cpu[:padded_total_num_scheduled_tokens]
         mrope_positions = self.mrope_positions_cpu[:, :
                                                    padded_total_num_scheduled_tokens]
-        block_tables = self.block_table_cpu[:self.max_num_reqs]
-        block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
-            self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs])

         # TODO(pooyam): Some paddings are up to `num_reqs_paddings` (spec decoding, select hidden states, etc) and some other are to `max_num_reqs` (block table, seq_lens). We should stick to one of them maybe?
         query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1]
@@ -1523,16 +1601,44 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             self.mesh, self.input_batch, padded_num_reqs)
         if self.uses_mrope:
             positions = mrope_positions
-
-        # Convert block_tables to 1D on cpu.
-        block_tables = block_tables.reshape(-1)
-
         query_start_loc_cpu = query_start_loc
         seq_lens_cpu = seq_lens
-
+
+        (input_ids, positions, query_start_loc, seq_lens,
          logits_indices, request_distribution) = device_array(
-            self.mesh, (input_ids, positions,
-
+            self.mesh, (input_ids, positions, query_start_loc, seq_lens,
+                        logits_indices, request_distribution))
+
+        attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+        uniform_attention_metadata: AttentionMetadata = None
+        for kv_cache_gid, kv_cache_group in enumerate(
+                self.kv_cache_config.kv_cache_groups):
+            block_tables = self.block_tables_cpu[kv_cache_gid][:self.
+                                                               max_num_reqs]
+            block_tables[:num_reqs] = (
+                self.input_batch.block_table[kv_cache_gid].get_cpu_tensor()
+                [:num_reqs])
+            # Convert block_tables to 1D on cpu.
+            block_tables = block_tables.reshape(-1)
+            block_tables = device_array(self.mesh, (block_tables))
+
+            attention_metadata_gid = AttentionMetadata(
+                input_positions=positions,
+                block_tables=block_tables,
+                seq_lens=seq_lens,
+                query_start_loc=query_start_loc,
+                request_distribution=request_distribution)
+            # This is for making these cpu buffers hidden during tracing
+            attention_metadata_gid.query_start_loc_cpu = query_start_loc_cpu
+            attention_metadata_gid.seq_lens_cpu = seq_lens_cpu
+
+            if not self.use_hybrid_kvcache:
+                # all layers share the same attention metadata
+                uniform_attention_metadata = attention_metadata_gid
+            else:
+                for layer_name in kv_cache_group.layer_names:
+                    attention_metadata_per_layer[
+                        layer_name] = attention_metadata_gid

         if self.scheduler_config.async_scheduling and len(
                 token_in_tpu_cur_input_indices) > 0:
@@ -1545,20 +1651,15 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             self.lora_utils.set_active_loras(
                 num_scheduled_tokens_per_req, total_num_scheduled_tokens,
                 padded_total_num_scheduled_tokens)
-
-        attention_metadata = AttentionMetadata(
-            input_positions=positions,
-            block_tables=block_tables,
-            seq_lens=seq_lens,
-            query_start_loc=query_start_loc,
-            request_distribution=request_distribution)
-
-        # This is for making these cpu buffers hidden during tracing
-        attention_metadata.query_start_loc_cpu = query_start_loc_cpu
-        attention_metadata.seq_lens_cpu = seq_lens_cpu
         logits_indices_selector = None
-
-
+
+        if self.use_hybrid_kvcache:
+            attention_metadata = attention_metadata_per_layer
+        else:
+            attention_metadata = uniform_attention_metadata
+        return (input_ids, positions, attention_metadata, sampling_metadata,
+                logits_indices, spec_decode_metadata, logits_indices_selector,
+                padded_num_reqs)

     def _get_input_ids_embeds(self, input_ids: jax.Array,
                               mm_embeds: list[jax.Array]):
tpu_inference/spec_decode/jax/eagle3.py CHANGED

@@ -9,10 +9,13 @@ import numpy as np
 from vllm.config import VllmConfig

 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.logger import init_logger
 from tpu_inference.models.common.model_loader import get_model
 from tpu_inference.runner import utils as runner_utils
 from tpu_inference.utils import device_array

+logger = init_logger(__name__)
+

 class Eagle3Proposer:
     """A proposer for speculative decoding using the Eagle3 method.
@@ -51,8 +54,22 @@ class Eagle3Proposer:
         """Loads the draft model."""
         self.model_fn, self.compute_logits_fn, self.combine_hidden_states_fn, _, self.state, _, _ = get_model(
             self.vllm_config, self.rng_key, self.mesh, is_draft_model=True)
-
-        self.state.model
+
+        draft_embed_tokens = getattr(self.state.model, 'embed_tokens', None)
+        if draft_embed_tokens is None or ~jnp.any(
+                draft_embed_tokens.embedding):
+            logger.info(
+                "Draft model does not have embedding. Setting draft model's embed_tokens to target model's embed"
+            )
+            self.state.model.embed_tokens = target_model.model.embed
+        elif jnp.array_equal(draft_embed_tokens.embedding,
+                             target_model.model.embed.embedding):
+            logger.info(
+                "Draft model's embed_tokens is identical to target model's embed. Sharing the embedding."
+            )
+            self.state.model.embed_tokens = target_model.model.embed
+        else:
+            logger.info("Draft model has its own embed_tokens.")

     @functools.partial(jax.jit, static_argnums=(0, ))
     def _prepare_input_ids(
tpu_inference/tpu_info.py CHANGED

@@ -3,6 +3,7 @@ import os

 import requests

+from tpu_inference import envs
 from tpu_inference.logger import init_logger

 logger = init_logger(__name__)
@@ -32,14 +33,14 @@ def get_tpu_metadata(key: str = "") -> str:


 def get_tpu_type() -> str:
-    tpu_type =
+    tpu_type = envs.TPU_ACCELERATOR_TYPE
     if tpu_type is None:
         tpu_type = get_tpu_metadata(key="accelerator-type")
     return tpu_type


 def get_node_name() -> str:
-    tpu_name =
+    tpu_name = envs.TPU_NAME
     if not tpu_name:
         tpu_name = get_tpu_metadata(key="instance-id")
     return tpu_name
@@ -47,7 +48,7 @@ def get_node_name() -> str:

 def get_node_worker_id() -> int:
     """For multi-host TPU VM, this returns the worker id for the current node."""
-    worker_id =
+    worker_id = envs.TPU_WORKER_ID
     if worker_id is None:
         worker_id = get_tpu_metadata(key="agent-worker-number")
     if worker_id is None:
tpu_inference/utils.py CHANGED

@@ -1,5 +1,4 @@
 # SPDX-License-Identifier: Apache-2.0
-import os
 import time
 from collections import defaultdict
 from collections.abc import Sequence
@@ -14,8 +13,10 @@ from jax._src import mesh as mesh_lib
 from jax._src import xla_bridge as xb
 from jax._src.lib import xla_client as xc
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
-from vllm import envs
+from vllm import envs as vllm_envs
+from vllm import utils

+from tpu_inference import envs
 from tpu_inference.logger import init_logger

 GBYTES = 1024 * 1024 * 1024
@@ -57,10 +58,10 @@ def get_num_kv_heads_by_tp(num_kv_heads: int, tp_size: int) -> int:

 def hbm_usage_bytes(devices: Any) -> List[Tuple[int, int]]:
     usage = []
-    if
+    if vllm_envs.VLLM_TPU_USING_PATHWAYS:
         return pathways_hbm_usage_gb(devices)

-    multihost_backend =
+    multihost_backend = envs.TPU_MULTIHOST_BACKEND
     if multihost_backend == "ray":
         # MemoryStats is only supported for addressable PjRt devices.
         # Assume all the devices have similar memory usage for now.