tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.11.1.dev202511220812__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/lora/test_layers.py +0 -6
- tests/lora/utils.py +0 -8
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +2 -3
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +1 -1
- tpu_inference/executors/ray_distributed_executor.py +4 -1
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +77 -54
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +9 -9
- tpu_inference/models/jax/llama3.py +2 -1
- tpu_inference/models/jax/llama_eagle3.py +9 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +2 -1
- tpu_inference/models/jax/qwen2_5_vl.py +2 -1
- tpu_inference/models/jax/qwen3.py +2 -1
- tpu_inference/models/jax/utils/weight_utils.py +21 -8
- tpu_inference/models/vllm/vllm_model_wrapper.py +4 -4
- tpu_inference/platforms/tpu_platform.py +5 -2
- tpu_inference/runner/compilation_manager.py +33 -15
- tpu_inference/runner/kv_cache_manager.py +8 -2
- tpu_inference/runner/tpu_runner.py +187 -99
- tpu_inference/spec_decode/jax/eagle3.py +2 -1
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +5 -4
- tpu_inference/worker/tpu_worker.py +158 -22
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/METADATA +2 -2
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/RECORD +34 -39
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/top_level.txt +0 -0
tpu_inference/runner/kv_cache_manager.py
CHANGED

@@ -1,15 +1,16 @@
 import functools
-import math
 from typing import TYPE_CHECKING, Dict, List

 import jax
 import jax.numpy as jnp
+import numpy as np
 import vllm.envs as envs
 from jax.sharding import NamedSharding, PartitionSpec
 from torchax.ops.mappings import t2j_dtype
 from vllm.attention import Attention
 from vllm.attention.backends.abstract import AttentionType
 from vllm.config import get_layers_from_vllm_config
+from vllm.utils.math_utils import cdiv
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec, MLAAttentionSpec,
                                         SlidingWindowSpec)
@@ -175,6 +176,11 @@ class KVCacheManager:
         )
         self.runner.input_batch = new_input_batch
         self.runner.persistent_batch_manager.input_batch = new_input_batch
+        self.runner.block_tables_cpu = [
+            np.zeros((self.runner.max_num_reqs,
+                      cdiv(self.runner.max_model_len, block_size)),
+                     dtype=np.int32) for block_size in block_sizes
+        ]

     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         self.maybe_reinitialize_input_batch(kv_cache_config)
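Note: the new allocation sizes one CPU block table per KV-cache group, with max_num_reqs rows and cdiv(max_model_len, block_size) columns. A minimal, self-contained sketch of that sizing with hypothetical sizes (cdiv is reimplemented locally here; the diff imports it from vllm.utils.math_utils):

import numpy as np

def cdiv(a: int, b: int) -> int:  # ceiling division, mirroring vllm's helper
    return -(-a // b)

max_num_reqs, max_model_len = 8, 1024   # hypothetical sizes
block_sizes = [16, 128]                 # hypothetical: one per KV-cache group
block_tables_cpu = [
    np.zeros((max_num_reqs, cdiv(max_model_len, bs)), dtype=np.int32)
    for bs in block_sizes
]
assert [t.shape for t in block_tables_cpu] == [(8, 64), (8, 8)]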
@@ -190,7 +196,7 @@
         num_blocks = kv_cache_tensor.size // page_size_bytes
         dp_size = self.runner.vllm_config.sharding_config.total_dp_size
         # num_blocks must be a multiple of dp_size
-        num_blocks =
+        num_blocks = (num_blocks // dp_size) * dp_size
         # NOTE: we'll multiply the num_kv_heads by 2 in the function
         kv_cache = create_kv_caches(
             num_blocks=num_blocks,
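The rewritten line rounds the block count down so it divides evenly across the data-parallel ranks. A minimal sketch of the arithmetic:

def round_down_to_multiple(n: int, k: int) -> int:
    # (n // k) * k truncates n to the nearest multiple of k
    return (n // k) * k

assert round_down_to_multiple(1003, 4) == 1000  # 3 blocks dropped
assert round_down_to_multiple(1000, 4) == 1000  # already a multiple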
tpu_inference/runner/tpu_runner.py
CHANGED

@@ -153,6 +153,7 @@ class ExecuteModelState:
     spec_decode_metadata: Optional[SpecDecodeMetadata]
     kv_connector_output: Optional[KVConnectorOutput]
     logits_indices_selector: Optional[List[int]] = None
+    padded_num_reqs: Optional[int] = None


 @functools.partial(jax.jit, donate_argnums=(0, 1, 2))
@@ -190,18 +191,28 @@ def _substitute_placeholder_token(
     return input_ids.at[token_in_tpu_cur_input_indices].set(update_values)


-def
+def _jax_logprobs_to_lists(logprobs_tensors,
+                           logits_indices_selector=None,
+                           cu_num_generated_tokens=None):
+    """Convert JAX LogprobsTensors to LogprobsLists by converting JAX arrays to numpy."""
+    log_token_ids_list = logprobs_tensors.logprob_token_ids.tolist()
+    logprobs_list = logprobs_tensors.logprobs.tolist()
+    selected_token_ranks_list = logprobs_tensors.selected_token_ranks.tolist()
+
+    if logits_indices_selector is not None:
+        log_token_ids_list = [
+            log_token_ids_list[i] for i in logits_indices_selector
+        ]
+        logprobs_list = [logprobs_list[i] for i in logits_indices_selector]
+        selected_token_ranks_list = [
+            selected_token_ranks_list[i] for i in logits_indices_selector
+        ]
+
     return LogprobsLists(
-        logprob_token_ids=
-
-
-
-        logprobs=[logprobs_lists.logprobs[i] for i in logits_indices_selector],
-        sampled_token_ranks=[
-            logprobs_lists.sampled_token_ranks[i]
-            for i in logits_indices_selector
-        ],
-        cu_num_generated_tokens=logprobs_lists.cu_num_generated_tokens,
+        logprob_token_ids=np.asarray(log_token_ids_list),
+        logprobs=np.asarray(logprobs_list),
+        sampled_token_ranks=np.asarray(selected_token_ranks_list),
+        cu_num_generated_tokens=cu_num_generated_tokens,
     )

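A toy illustration of what the new helper does; the names below are stand-ins, not the vllm LogprobsTensors/LogprobsLists types. JAX arrays are pulled to Python lists with tolist(), rows are optionally regathered into the pre-data-parallel-shuffling order via logits_indices_selector, and the result is repacked as numpy arrays:

import numpy as np

logprob_token_ids = np.array([[11], [22], [33]])  # stand-in for a jax.Array
rows = logprob_token_ids.tolist()
logits_indices_selector = [2, 0, 1]               # hypothetical selector
rows = [rows[i] for i in logits_indices_selector]
assert np.asarray(rows).tolist() == [[33], [11], [22]]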
@@ -211,6 +222,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         self,
         vllm_config: VllmConfig,
         devices: List[Any],
+        rank: int = 0,
+        is_first_rank: bool = True,
+        is_last_rank: bool = True,
     ):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
@@ -423,8 +437,14 @@

         self.input_ids_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
         self.positions_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
-        self.
-
+        # Note: self.input_batch and self.block_tables_cpu are both initialized
+        # with only 1 block_size. For hybrid kv cache, it will be re-init
+        # in kv_cache_manager's maybe_reinitialize_input_batch.
+        self.block_tables_cpu = [
+            np.zeros((self.max_num_reqs, self.max_num_blocks_per_req),
+                     dtype=np.int32)
+        ]
+
         self.query_start_loc_cpu = np.zeros(self.max_num_reqs + self.dp_size,
                                             dtype=np.int32)
         self.seq_lens_cpu = np.zeros(self.max_num_reqs, dtype=np.int32)
@@ -458,9 +478,6 @@

         # tensors for structured decoding
         self.vocab_size = self.model_config.get_vocab_size()
-        if self.lora_config is not None:
-            # lora_config.lora_extra_vocab_size is the "Maximum size of extra vocabulary that can be present in a LoRA adapter" per https://github.com/vanbasten23/vllm/blob/7f4a8b6705622fde952a2e633e86716f902d6e1b/vllm/config.py#L3040
-            self.vocab_size += self.lora_config.lora_extra_vocab_size
         self.grammar_bitmask_cpu = np.zeros(
             (self.max_num_reqs, cdiv(self.vocab_size, 32)),
             dtype=np.int32,
@@ -505,9 +522,14 @@

         self.rng_params_for_sampling = nnx.Rngs(
             jax.random.key(self.model_config.seed)).params()
-        self.is_multimodal_model = (
-
-
+        self.is_multimodal_model = (
+            self.model_config.is_multimodal_model
+            and self.get_multimodal_embeddings_fn is not None and hasattr(
+                self.model_config.hf_config, "architectures"
+            )  #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
+            and len(self.model_config.hf_config.architectures) >= 1
+            and self.model_config.hf_config.architectures[0]
+            != "Llama4ForConditionalGeneration")

         logger.info(f"Init model | "
                     f"hbm={common_utils.hbm_usage_gb(self.devices)}GiB")
@@ -520,6 +542,7 @@

     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         self.kv_cache_config = kv_cache_config
+        self.use_hybrid_kvcache = len(kv_cache_config.kv_cache_groups) > 1
         self.kv_caches = []
         self.kv_cache_manager.initialize_kv_cache(kv_cache_config)
         if has_kv_transfer_group():
@@ -550,16 +573,17 @@

         (scheduler_output, attn_metadata, input_ids, hidden_states, logits,
          aux_hidden_states, spec_decode_metadata, kv_connector_output,
-         logits_indices_selector
-
-
-
-
-
-
-
-
+         logits_indices_selector,
+         padded_num_reqs) = (self.execute_model_state.scheduler_output,
+                             self.execute_model_state.attn_metadata,
+                             self.execute_model_state.input_ids,
+                             self.execute_model_state.hidden_states,
+                             self.execute_model_state.logits,
+                             self.execute_model_state.aux_hidden_states,
+                             self.execute_model_state.spec_decode_metadata,
+                             self.execute_model_state.kv_connector_output,
+                             self.execute_model_state.logits_indices_selector,
+                             self.execute_model_state.padded_num_reqs)
         self.execute_model_state = None

         if grammar_output is not None:
@@ -573,12 +597,10 @@
             logits,
             arange,
         )
-        return self._sample_from_logits(
-
-
-
-            kv_connector_output,
-            logits_indices_selector)
+        return self._sample_from_logits(
+            scheduler_output, attn_metadata, input_ids, hidden_states, logits,
+            aux_hidden_states, spec_decode_metadata, kv_connector_output,
+            logits_indices_selector, padded_num_reqs)

     def _modify_prev_results(self):
         # If copy to host has not been done, we just wait.
@@ -687,13 +709,23 @@
         # TODO(pooyam): I guess we can remove returning sampling_metadata in `_prepare_inputs` after https://github.com/njhill/vllm/commit/b7433ca1a47732394b1bdea4099d98389515954b
         (
             input_ids,
+            input_positions,
             attn_metadata,
             _,
             logits_indices,
             spec_decode_metadata,
             logits_indices_selector,
+            padded_num_reqs,
         ) = self._prepare_inputs(scheduler_output)

+        is_llama_guard_4 = (
+            hasattr(
+                self.model_config.hf_config, "architectures"
+            )  #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
+            and len(self.model_config.hf_config.architectures) >= 1
+            and self.model_config.hf_config.architectures[0]
+            == "Llama4ForConditionalGeneration")
+
         # multi-modal support
         if self.is_multimodal_model:
             # Run the multimodal encoder if any.
@@ -701,6 +733,13 @@
             self.mm_manager.execute_mm_encoder(scheduler_output)
             mm_embeds = self.mm_manager.gather_mm_embeddings(
                 scheduler_output, input_ids.shape[0])
+        #TODO: Remove the follow elif statement once Llama Guard 4 Vision portion has been implemented
+        elif is_llama_guard_4 and any(
+                self.mm_manager.runner.requests[req_id].mm_features
+                for req_id in self.mm_manager.runner.input_batch.req_ids):
+            raise NotImplementedError(
+                "Llama Guard 4 (JAX) currently supports only text inputs. "
+                "Multimodal processing not yet implemented.")
         else:
             mm_embeds = []

@@ -733,6 +772,7 @@
             input_ids,
             attn_metadata,
             inputs_embeds,
+            input_positions,
             tuple(self.layer_name_to_kvcache_index.items()),
             lora_metadata,
         )
@@ -754,7 +794,8 @@
             aux_hidden_states=aux_hidden_states,
             spec_decode_metadata=spec_decode_metadata,
             kv_connector_output=kv_connector_output,
-            logits_indices_selector=logits_indices_selector
+            logits_indices_selector=logits_indices_selector,
+            padded_num_reqs=padded_num_reqs)
         return attn_metadata, None

     def _sample_from_logits(
@@ -768,11 +809,19 @@
         spec_decode_metadata: Optional[SpecDecodeMetadata],
         kv_connector_output: Optional[KVConnectorOutput],
         logits_indices_selector: Optional[List[int]] = None,
+        padded_num_reqs: Optional[int] = None,
    ) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput:
-        padded_num_reqs
-
+        if padded_num_reqs is None:
+            padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
+                self.input_batch.num_reqs, self.max_num_reqs)
+
+        sharding = None
+        if self.dp_size > 1:
+            sharding = NamedSharding(self.mesh,
+                                     PartitionSpec(ShardingAxisName.ATTN_DATA))
+
         tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
-            self.mesh, self.input_batch, padded_num_reqs)
+            self.mesh, self.input_batch, padded_num_reqs, sharding=sharding)
         if spec_decode_metadata is None:
             next_tokens = sample(
                 self.rng_params_for_sampling,
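For context, a hedged JAX sketch of the object built in the dp_size > 1 branch above: a NamedSharding that partitions sampling inputs along a data-parallel mesh axis. The axis name below is illustrative; the diff uses ShardingAxisName.ATTN_DATA:

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices())             # 1-D device mesh for the sketch
mesh = Mesh(devices, axis_names=("attn_dp",))
sharding = NamedSharding(mesh, PartitionSpec("attn_dp"))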
@@ -856,10 +905,8 @@

         if logprobs is not None:
             # Map logprobs back to the pre-dp shuffling order
-            logprobs_lists =
-
-            logprobs_lists = _reorder_logits_indices(
-                logprobs_lists, logits_indices_selector)
+            logprobs_lists = _jax_logprobs_to_lists(
+                logprobs, logits_indices_selector)

         else:
             logprobs_lists = None
@@ -929,10 +976,8 @@

         if logprobs is not None:
             # Map logprobs back to the pre-dp shuffling order
-            logprobs_lists = logprobs
-
-            logprobs_lists = _reorder_logits_indices(
-                logprobs_lists, logits_indices_selector)
+            logprobs_lists = _jax_logprobs_to_lists(logprobs,
+                                                    logits_indices_selector)
         else:
             logprobs_lists = None

@@ -1280,16 +1325,6 @@
         mrope_positions = self.mrope_positions_cpu[:, :
                                                    padded_total_num_scheduled_tokens]

-        block_tables = self.block_table_cpu[:self.max_num_reqs]
-        for dp_rank in range(dp_size):
-            req_offset = dp_rank * max_num_reqs_per_dp_rank
-            _num_reqs = num_req_per_dp_rank[dp_rank]
-
-            block_tables[
-                req_offset:req_offset + _num_reqs, :self.
-                max_num_blocks_per_req] = self.input_batch.block_table[
-                    0].get_cpu_tensor()[req_indices_dp[dp_rank]]
-
         query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs +
                                                    dp_size]
         seq_lens = self.seq_lens_cpu[:self.max_num_reqs]
@@ -1331,20 +1366,59 @@
         if self.uses_mrope:
             positions = mrope_positions

-        # Convert block_tables to 1D on cpu.
-        block_tables = block_tables.reshape(-1)
-
         query_start_loc_cpu = query_start_loc
         logits_indices_cpu = logits_indices
         seq_lens_cpu = seq_lens

-        (input_ids, positions,
-
+        (input_ids, positions, query_start_loc, seq_lens, logits_indices,
+         request_distribution) = device_array(
             self.mesh,
-            (input_ids, positions,
-
+            (input_ids, positions, query_start_loc, seq_lens, logits_indices,
+             request_distribution),
             sharding=data_parallel_attn_sharding,
         )
+
+        attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+        uniform_attention_metadata: AttentionMetadata = None
+        for kv_cache_gid, kv_cache_group in enumerate(
+                self.kv_cache_config.kv_cache_groups):
+            block_tables = self.block_tables_cpu[kv_cache_gid][:self.
+                                                               max_num_reqs]
+            for dp_rank in range(dp_size):
+                req_offset = dp_rank * max_num_reqs_per_dp_rank
+                _num_reqs = num_req_per_dp_rank[dp_rank]
+
+                block_tables[
+                    req_offset:req_offset + _num_reqs, :self.
+                    max_num_blocks_per_req] = self.input_batch.block_table[
+                        0].get_cpu_tensor()[req_indices_dp[dp_rank]]
+            # Convert block_tables to 1D on cpu.
+            block_tables = block_tables.reshape(-1)
+            block_tables = device_array(
+                self.mesh,
+                (block_tables),
+                sharding=data_parallel_attn_sharding,
+            )
+
+            attention_metadata_gid = AttentionMetadata(
+                input_positions=positions,
+                block_tables=block_tables,
+                seq_lens=seq_lens,
+                query_start_loc=query_start_loc,
+                request_distribution=request_distribution,
+            )
+
+            # This is for making these cpu buffers hidden during tracing
+            attention_metadata_gid.query_start_loc_cpu = query_start_loc_cpu
+            attention_metadata_gid.seq_lens_cpu = seq_lens_cpu
+
+            if not self.use_hybrid_kvcache:
+                uniform_attention_metadata = attention_metadata_gid
+            else:
+                for layer_name in kv_cache_group.layer_names:
+                    attention_metadata_per_layer[
+                        layer_name] = attention_metadata_gid
+
         # Async scheduling: substitute placeholder tokens for DP
         if self.scheduler_config.async_scheduling and self._pre_async_results is not None:
             # Collect all token indices that need substitution across all DP ranks
@@ -1373,25 +1447,19 @@
             padded_total_num_scheduled_tokens,
         )

-
-
-
-
-            query_start_loc=query_start_loc,
-            request_distribution=request_distribution,
-        )
-
-        # This is for making these cpu buffers hidden during tracing
-        attention_metadata.query_start_loc_cpu = query_start_loc_cpu
-        attention_metadata.seq_lens_cpu = seq_lens_cpu
-
+        if self.use_hybrid_kvcache:
+            attention_metadata = attention_metadata_per_layer
+        else:
+            attention_metadata = uniform_attention_metadata
         return (
             input_ids,
+            positions,
             attention_metadata,
             sampling_metadata,
             logits_indices,
             spec_decode_metadata,
             logits_indices_selector,
+            padded_num_reqs,
         )

     def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"):
@@ -1492,9 +1560,6 @@
         positions = self.positions_cpu[:padded_total_num_scheduled_tokens]
         mrope_positions = self.mrope_positions_cpu[:, :
                                                    padded_total_num_scheduled_tokens]
-        block_tables = self.block_table_cpu[:self.max_num_reqs]
-        block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
-            self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs])

         # TODO(pooyam): Some paddings are up to `num_reqs_paddings` (spec decoding, select hidden states, etc) and some other are to `max_num_reqs` (block table, seq_lens). We should stick to one of them maybe?
         query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1]
@@ -1523,16 +1588,44 @@
             self.mesh, self.input_batch, padded_num_reqs)
         if self.uses_mrope:
             positions = mrope_positions
-
-        # Convert block_tables to 1D on cpu.
-        block_tables = block_tables.reshape(-1)
-
         query_start_loc_cpu = query_start_loc
         seq_lens_cpu = seq_lens
-
+
+        (input_ids, positions, query_start_loc, seq_lens,
          logits_indices, request_distribution) = device_array(
-            self.mesh, (input_ids, positions,
-
+            self.mesh, (input_ids, positions, query_start_loc, seq_lens,
+                        logits_indices, request_distribution))
+
+        attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+        uniform_attention_metadata: AttentionMetadata = None
+        for kv_cache_gid, kv_cache_group in enumerate(
+                self.kv_cache_config.kv_cache_groups):
+            block_tables = self.block_tables_cpu[kv_cache_gid][:self.
+                                                               max_num_reqs]
+            block_tables[:num_reqs] = (
+                self.input_batch.block_table[kv_cache_gid].get_cpu_tensor()
+                [:num_reqs])
+            # Convert block_tables to 1D on cpu.
+            block_tables = block_tables.reshape(-1)
+            block_tables = device_array(self.mesh, (block_tables))
+
+            attention_metadata_gid = AttentionMetadata(
+                input_positions=positions,
+                block_tables=block_tables,
+                seq_lens=seq_lens,
+                query_start_loc=query_start_loc,
+                request_distribution=request_distribution)
+            # This is for making these cpu buffers hidden during tracing
+            attention_metadata_gid.query_start_loc_cpu = query_start_loc_cpu
+            attention_metadata_gid.seq_lens_cpu = seq_lens_cpu
+
+            if not self.use_hybrid_kvcache:
+                # all layers share the same attention metadata
+                uniform_attention_metadata = attention_metadata_gid
+            else:
+                for layer_name in kv_cache_group.layer_names:
+                    attention_metadata_per_layer[
+                        layer_name] = attention_metadata_gid

         if self.scheduler_config.async_scheduling and len(
                 token_in_tpu_cur_input_indices) > 0:
@@ -1545,20 +1638,15 @@
         self.lora_utils.set_active_loras(
             num_scheduled_tokens_per_req, total_num_scheduled_tokens,
             padded_total_num_scheduled_tokens)
-
-        attention_metadata = AttentionMetadata(
-            input_positions=positions,
-            block_tables=block_tables,
-            seq_lens=seq_lens,
-            query_start_loc=query_start_loc,
-            request_distribution=request_distribution)
-
-        # This is for making these cpu buffers hidden during tracing
-        attention_metadata.query_start_loc_cpu = query_start_loc_cpu
-        attention_metadata.seq_lens_cpu = seq_lens_cpu
         logits_indices_selector = None
-
-
+
+        if self.use_hybrid_kvcache:
+            attention_metadata = attention_metadata_per_layer
+        else:
+            attention_metadata = uniform_attention_metadata
+        return (input_ids, positions, attention_metadata, sampling_metadata,
+                logits_indices, spec_decode_metadata, logits_indices_selector,
+                padded_num_reqs)

     def _get_input_ids_embeds(self, input_ids: jax.Array,
                               mm_embeds: list[jax.Array]):
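Taken together, the two _prepare_inputs variants now build one AttentionMetadata per KV-cache group and either share it across all layers or key it by layer name. A simplified, self-contained sketch of that dispatch, with stand-in types and hypothetical layer names:

from dataclasses import dataclass
from typing import Dict, List

@dataclass
class Meta:            # stand-in for AttentionMetadata
    group_id: int

groups: List[List[str]] = [["layers.0", "layers.2"], ["layers.1"]]  # hypothetical
use_hybrid_kvcache = len(groups) > 1
per_layer: Dict[str, Meta] = {}
uniform: Meta = None
for gid, layer_names in enumerate(groups):
    meta = Meta(gid)
    if not use_hybrid_kvcache:
        uniform = meta                  # all layers share one metadata object
    else:
        for name in layer_names:
            per_layer[name] = meta      # each layer maps to its group's metadata
attention_metadata = per_layer if use_hybrid_kvcache else uniform
assert attention_metadata["layers.1"].group_id == 1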
tpu_inference/spec_decode/jax/eagle3.py
CHANGED

@@ -51,7 +51,8 @@ class Eagle3Proposer:
         """Loads the draft model."""
         self.model_fn, self.compute_logits_fn, self.combine_hidden_states_fn, _, self.state, _, _ = get_model(
             self.vllm_config, self.rng_key, self.mesh, is_draft_model=True)
-
+        if 'embed_tokens' in self.state.model:
+            del self.state.model['embed_tokens']
         self.state.model.embed_tokens = target_model.model.embed

     @functools.partial(jax.jit, static_argnums=(0, ))
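The added guard drops the draft model's own embedding entry before aliasing the target model's embedding, so stale draft weights cannot shadow the shared ones. A plain-dict sketch of the intent (the real object is Flax NNX module state, not a dict):

state_model = {"embed_tokens": "draft weights", "layers": "..."}
if "embed_tokens" in state_model:
    del state_model["embed_tokens"]             # remove the draft copy first
state_model["embed_tokens"] = "target weights"  # tie to the target model
assert state_model["embed_tokens"] == "target weights"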
tpu_inference/tpu_info.py
CHANGED
@@ -3,6 +3,7 @@ import os

 import requests

+from tpu_inference import envs
 from tpu_inference.logger import init_logger

 logger = init_logger(__name__)
@@ -32,14 +33,14 @@ def get_tpu_metadata(key: str = "") -> str:
|
|
|
32
33
|
|
|
33
34
|
|
|
34
35
|
def get_tpu_type() -> str:
|
|
35
|
-
tpu_type =
|
|
36
|
+
tpu_type = envs.TPU_ACCELERATOR_TYPE
|
|
36
37
|
if tpu_type is None:
|
|
37
38
|
tpu_type = get_tpu_metadata(key="accelerator-type")
|
|
38
39
|
return tpu_type
|
|
39
40
|
|
|
40
41
|
|
|
41
42
|
def get_node_name() -> str:
|
|
42
|
-
tpu_name =
|
|
43
|
+
tpu_name = envs.TPU_NAME
|
|
43
44
|
if not tpu_name:
|
|
44
45
|
tpu_name = get_tpu_metadata(key="instance-id")
|
|
45
46
|
return tpu_name
|
|
@@ -47,7 +48,7 @@ def get_node_name() -> str:


 def get_node_worker_id() -> int:
     """For multi-host TPU VM, this returns the worker id for the current node."""
-    worker_id =
+    worker_id = envs.TPU_WORKER_ID
     if worker_id is None:
         worker_id = get_tpu_metadata(key="agent-worker-number")
     if worker_id is None:
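All three getters in tpu_info.py now share the same lookup order: prefer the tpu_inference.envs value and fall back to the GCE metadata server only when it is unset. A sketch of the pattern (fetch_metadata is a hypothetical stand-in for get_tpu_metadata):

def resolve_tpu_setting(env_value, metadata_key, fetch_metadata):
    # the environment value wins when set; otherwise ask the metadata server
    if env_value:
        return env_value
    return fetch_metadata(metadata_key)

assert resolve_tpu_setting("v6e-8", "accelerator-type", lambda k: "v5e") == "v6e-8"
assert resolve_tpu_setting(None, "accelerator-type", lambda k: "v5e") == "v5e"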
tpu_inference/utils.py
CHANGED
@@ -1,5 +1,4 @@
 # SPDX-License-Identifier: Apache-2.0
-import os
 import time
 from collections import defaultdict
 from collections.abc import Sequence
@@ -14,8 +13,10 @@ from jax._src import mesh as mesh_lib
 from jax._src import xla_bridge as xb
 from jax._src.lib import xla_client as xc
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
-from vllm import envs
+from vllm import envs as vllm_envs
+from vllm import utils

+from tpu_inference import envs
 from tpu_inference.logger import init_logger

 GBYTES = 1024 * 1024 * 1024
@@ -57,10 +58,10 @@ def get_num_kv_heads_by_tp(num_kv_heads: int, tp_size: int) -> int:

 def hbm_usage_bytes(devices: Any) -> List[Tuple[int, int]]:
     usage = []
-    if
+    if vllm_envs.VLLM_TPU_USING_PATHWAYS:
         return pathways_hbm_usage_gb(devices)

-    multihost_backend =
+    multihost_backend = envs.TPU_MULTIHOST_BACKEND
     if multihost_backend == "ray":
         # MemoryStats is only supported for addressable PjRt devices.
         # Assume all the devices have similar memory usage for now.