tpu-inference 0.11.1.dev202511130813__py3-none-any.whl → 0.11.1.dev202511220812__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference has been flagged as potentially problematic.
- tests/lora/test_layers.py +0 -6
- tests/lora/utils.py +0 -8
- tests/test_envs.py +182 -0
- tests/test_utils.py +23 -14
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/core_tpu.py +17 -9
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +2 -3
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +1 -1
- tpu_inference/executors/ray_distributed_executor.py +27 -11
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +110 -64
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +7 -0
- tpu_inference/layers/{jax → common}/attention_interface.py +1 -1
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/jax/attention/attention.py +1 -1
- tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
- tpu_inference/layers/jax/sample/sampling.py +2 -2
- tpu_inference/layers/vllm/attention.py +1 -1
- tpu_inference/layers/vllm/quantization/__init__.py +7 -3
- tpu_inference/layers/vllm/quantization/awq.py +4 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -2
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +4 -3
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +12 -11
- tpu_inference/models/jax/llama3.py +4 -3
- tpu_inference/models/jax/llama_eagle3.py +9 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +3 -2
- tpu_inference/models/jax/qwen2_5_vl.py +4 -3
- tpu_inference/models/jax/qwen3.py +3 -2
- tpu_inference/models/jax/utils/weight_utils.py +21 -8
- tpu_inference/models/vllm/vllm_model_wrapper.py +22 -10
- tpu_inference/platforms/tpu_platform.py +17 -7
- tpu_inference/runner/compilation_manager.py +37 -17
- tpu_inference/runner/kv_cache.py +1 -1
- tpu_inference/runner/kv_cache_manager.py +8 -2
- tpu_inference/runner/tpu_runner.py +199 -87
- tpu_inference/spec_decode/jax/eagle3.py +2 -1
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +7 -6
- tpu_inference/worker/tpu_worker.py +159 -23
- {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/METADATA +2 -2
- {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/RECORD +52 -54
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- /tpu_inference/layers/{jax → common}/binary_search.py +0 -0
- /tpu_inference/layers/{jax → common}/sharding.py +0 -0
- {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/top_level.txt +0 -0
tpu_inference/runner/compilation_manager.py
CHANGED
@@ -1,6 +1,6 @@
 import os
 import time
-from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
 
 import jax
 import jax.numpy as jnp
@@ -10,10 +10,10 @@ from jax.sharding import NamedSharding, PartitionSpec
 
 from tpu_inference.core.disagg_utils import is_disagg_enabled
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.jax.sample.sampling import sample
 from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata
-from tpu_inference.layers.jax.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
 from tpu_inference.utils import device_array
 
@@ -135,12 +135,6 @@ class CompilationManager:
             ShardingAxisName.ATTN_DATA, )) if dp_size > 1 else None
 
         # Keep existing pattern for complex array operations
-        block_tables = self.runner.block_table_cpu[:self.runner.max_num_reqs]
-        block_tables = block_tables.reshape(-1)
-        block_tables = device_array(self.runner.mesh,
-                                    block_tables,
-                                    sharding=dp_sharding)
-
         seq_lens = self._create_dummy_tensor((self.runner.max_num_reqs, ),
                                              jnp.int32, dp_sharding)
         query_start_loc = self._create_dummy_tensor(
@@ -152,26 +146,45 @@ class CompilationManager:
            request_distribution,
            sharding=dp_sharding)
 
-
-
-
-
-
-
-
+        attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+        uniform_attention_metadata: AttentionMetadata = None
+        for kv_cache_gid, kv_cache_group in enumerate(
+                self.runner.kv_cache_config.kv_cache_groups):
+            block_tables = self.runner.block_tables_cpu[
+                kv_cache_gid][:self.runner.max_num_reqs]
+            block_tables = block_tables.reshape(-1)
+            block_tables = device_array(self.runner.mesh,
+                                        block_tables,
+                                        sharding=dp_sharding)
+
+            attention_metadata_gid = AttentionMetadata(
+                input_positions=positions,
+                block_tables=block_tables,
+                seq_lens=seq_lens,
+                query_start_loc=query_start_loc,
+                request_distribution=request_distribution,
+            )
+            if not self.runner.use_hybrid_kvcache:
+                # all layers share the same attention metadata
+                uniform_attention_metadata = attention_metadata_gid
+            else:
+                for layer_name in kv_cache_group.layer_names:
+                    attention_metadata_per_layer[
+                        layer_name] = attention_metadata_gid
 
        def model_fn_wrapper(
            state,
            kv_caches,
            input_ids,
            attention_metadata,
+           positions,
            inputs_embeds,
            layer_name_to_kvcache_index,
            lora_metadata,
        ):
            kv_caches, hidden_states, _ = self.runner.model_fn(
                state, kv_caches, input_ids, attention_metadata, inputs_embeds,
-               layer_name_to_kvcache_index, lora_metadata)
+               positions, layer_name_to_kvcache_index, lora_metadata)
            self.runner.kv_caches = kv_caches
            return hidden_states
 
@@ -179,6 +192,10 @@ class CompilationManager:
                self.runner.lora_config, np.array([num_tokens],
                                                   dtype=np.int32)):
            lora_metadata = self.runner.lora_utils.extract_lora_metadata()
+           if self.runner.use_hybrid_kvcache:
+               attention_metadata = attention_metadata_per_layer
+           else:
+               attention_metadata = uniform_attention_metadata
            self._run_compilation(
                name,
                model_fn_wrapper,
@@ -186,6 +203,7 @@
                self.runner.kv_caches,
                input_ids,
                attention_metadata,
+               positions,
                inputs_embeds,
                tuple(self.runner.layer_name_to_kvcache_index.items()),
                lora_metadata,
@@ -332,13 +350,15 @@ class CompilationManager:
        index_paddings = self.runner.num_reqs_paddings
        dp_sharding = NamedSharding(self.runner.mesh,
                                    PartitionSpec(ShardingAxisName.ATTN_DATA))
+       hidden_states_sharding = NamedSharding(
+           self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None))
        dp_size = self.runner.vllm_config.sharding_config.total_dp_size
        self._precompile_select_from_array_helper(
            name="select all logits",
            source_paddings=self.runner.num_tokens_paddings,
            indices_paddings=index_paddings,
            hidden_dim=hsize,
-           input_sharding=
+           input_sharding=hidden_states_sharding,
            indices_sharding=dp_sharding if dp_size > 1 else None,
        )
 
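The hunks above replace the single precompiled block table with one AttentionMetadata per KV-cache group, falling back to a single shared object when there is only one group. The selection logic can be illustrated with a small standalone sketch; the classes below are simplified stand-ins, not the package's own types:

    from dataclasses import dataclass
    from typing import Dict, List, Optional

    @dataclass
    class AttentionMetadata:            # simplified stand-in for the real class
        block_tables: List[int]
        seq_lens: List[int]

    @dataclass
    class KVCacheGroup:                 # stand-in for a KV-cache group spec
        layer_names: List[str]

    def build_metadata(groups, block_tables_per_group, seq_lens):
        """Return one shared metadata object, or a per-layer dict for a hybrid cache."""
        use_hybrid_kvcache = len(groups) > 1
        per_layer: Dict[str, AttentionMetadata] = {}
        uniform: Optional[AttentionMetadata] = None
        for gid, group in enumerate(groups):
            md = AttentionMetadata(block_tables=block_tables_per_group[gid],
                                   seq_lens=seq_lens)
            if not use_hybrid_kvcache:
                uniform = md            # every layer shares the same metadata
            else:
                for name in group.layer_names:
                    per_layer[name] = md
        return per_layer if use_hybrid_kvcache else uniform

    # Two groups (e.g. full attention plus sliding window) -> a per-layer dict.
    groups = [KVCacheGroup(["layers.0", "layers.2"]), KVCacheGroup(["layers.1"])]
    print(build_metadata(groups, [[0, 1], [2, 3]], [128, 64]))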
tpu_inference/runner/kv_cache.py
CHANGED
@@ -9,7 +9,7 @@ from torchax.ops.mappings import t2j_dtype
 
 import tpu_inference.kernels.ragged_paged_attention.v3.kernel as rpa
 import tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 as rpa_hd64
-from tpu_inference.layers.
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
 
 logger = init_logger(__name__)
tpu_inference/runner/kv_cache_manager.py
CHANGED
@@ -1,15 +1,16 @@
 import functools
-import math
 from typing import TYPE_CHECKING, Dict, List
 
 import jax
 import jax.numpy as jnp
+import numpy as np
 import vllm.envs as envs
 from jax.sharding import NamedSharding, PartitionSpec
 from torchax.ops.mappings import t2j_dtype
 from vllm.attention import Attention
 from vllm.attention.backends.abstract import AttentionType
 from vllm.config import get_layers_from_vllm_config
+from vllm.utils.math_utils import cdiv
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec, MLAAttentionSpec,
                                         SlidingWindowSpec)
@@ -175,6 +176,11 @@ class KVCacheManager:
        )
        self.runner.input_batch = new_input_batch
        self.runner.persistent_batch_manager.input_batch = new_input_batch
+       self.runner.block_tables_cpu = [
+           np.zeros((self.runner.max_num_reqs,
+                     cdiv(self.runner.max_model_len, block_size)),
+                    dtype=np.int32) for block_size in block_sizes
+       ]
 
    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        self.maybe_reinitialize_input_batch(kv_cache_config)
@@ -190,7 +196,7 @@
        num_blocks = kv_cache_tensor.size // page_size_bytes
        dp_size = self.runner.vllm_config.sharding_config.total_dp_size
        # num_blocks must be a multiple of dp_size
-       num_blocks =
+       num_blocks = (num_blocks // dp_size) * dp_size
        # NOTE: we'll multiply the num_kv_heads by 2 in the function
        kv_cache = create_kv_caches(
            num_blocks=num_blocks,
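To make the sizing arithmetic above concrete, a short worked example with made-up values; the cdiv helper here mirrors ceiling division as used to size the per-request block tables:

    import numpy as np

    def cdiv(a: int, b: int) -> int:
        # ceiling division
        return -(-a // b)

    max_num_reqs, max_model_len, dp_size = 8, 8192, 4
    block_sizes = [16, 128]   # hypothetical: one block size per KV-cache group

    # One CPU block table per group, wide enough for the worst-case request length.
    block_tables_cpu = [
        np.zeros((max_num_reqs, cdiv(max_model_len, bs)), dtype=np.int32)
        for bs in block_sizes
    ]
    print([bt.shape for bt in block_tables_cpu])   # [(8, 512), (8, 64)]

    # num_blocks must be a multiple of dp_size, so it is rounded down.
    num_blocks = 1023
    num_blocks = (num_blocks // dp_size) * dp_size
    print(num_blocks)                              # 1020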
tpu_inference/runner/tpu_runner.py
CHANGED
@@ -27,7 +27,7 @@ from vllm.v1.core.sched.output import GrammarOutput
 from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
-                             DraftTokenIds, KVConnectorOutput,
+                             DraftTokenIds, KVConnectorOutput, LogprobsLists,
                              ModelRunnerOutput)
 from vllm.v1.request import Request
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
@@ -37,15 +37,15 @@ from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 
 from tpu_inference import utils as common_utils
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.sharding import (MESH_AXIS_NAMES,
+                                                  MESH_AXIS_NAMES_2D,
+                                                  ShardingAxisName,
+                                                  ShardingConfigManager)
 from tpu_inference.layers.jax.sample.rejection_sampler import RejectionSampler
 from tpu_inference.layers.jax.sample.sampling import (compute_logprobs,
                                                       gather_logprobs, sample)
 from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata
-from tpu_inference.layers.jax.sharding import (MESH_AXIS_NAMES,
-                                               MESH_AXIS_NAMES_2D,
-                                               ShardingAxisName,
-                                               ShardingConfigManager)
 from tpu_inference.logger import init_logger
 from tpu_inference.models.common.model_loader import get_model
 from tpu_inference.models.jax.utils.weight_utils import (
@@ -153,6 +153,7 @@ class ExecuteModelState:
    spec_decode_metadata: Optional[SpecDecodeMetadata]
    kv_connector_output: Optional[KVConnectorOutput]
    logits_indices_selector: Optional[List[int]] = None
+   padded_num_reqs: Optional[int] = None
 
 
 @functools.partial(jax.jit, donate_argnums=(0, 1, 2))
@@ -190,12 +191,40 @@ def _substitute_placeholder_token(
    return input_ids.at[token_in_tpu_cur_input_indices].set(update_values)
 
 
+def _jax_logprobs_to_lists(logprobs_tensors,
+                           logits_indices_selector=None,
+                           cu_num_generated_tokens=None):
+    """Convert JAX LogprobsTensors to LogprobsLists by converting JAX arrays to numpy."""
+    log_token_ids_list = logprobs_tensors.logprob_token_ids.tolist()
+    logprobs_list = logprobs_tensors.logprobs.tolist()
+    selected_token_ranks_list = logprobs_tensors.selected_token_ranks.tolist()
+
+    if logits_indices_selector is not None:
+        log_token_ids_list = [
+            log_token_ids_list[i] for i in logits_indices_selector
+        ]
+        logprobs_list = [logprobs_list[i] for i in logits_indices_selector]
+        selected_token_ranks_list = [
+            selected_token_ranks_list[i] for i in logits_indices_selector
+        ]
+
+    return LogprobsLists(
+        logprob_token_ids=np.asarray(log_token_ids_list),
+        logprobs=np.asarray(logprobs_list),
+        sampled_token_ranks=np.asarray(selected_token_ranks_list),
+        cu_num_generated_tokens=cu_num_generated_tokens,
+    )
+
+
 class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
    def __init__(
        self,
        vllm_config: VllmConfig,
        devices: List[Any],
+       rank: int = 0,
+       is_first_rank: bool = True,
+       is_last_rank: bool = True,
    ):
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
@@ -408,8 +437,14 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
        self.input_ids_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
        self.positions_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
-       self.
-
+       # Note: self.input_batch and self.block_tables_cpu are both initialized
+       # with only 1 block_size. For hybrid kv cache, it will be re-init
+       # in kv_cache_manager's maybe_reinitialize_input_batch.
+       self.block_tables_cpu = [
+           np.zeros((self.max_num_reqs, self.max_num_blocks_per_req),
+                    dtype=np.int32)
+       ]
+
        self.query_start_loc_cpu = np.zeros(self.max_num_reqs + self.dp_size,
                                            dtype=np.int32)
        self.seq_lens_cpu = np.zeros(self.max_num_reqs, dtype=np.int32)
@@ -443,9 +478,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
        # tensors for structured decoding
        self.vocab_size = self.model_config.get_vocab_size()
-       if self.lora_config is not None:
-           # lora_config.lora_extra_vocab_size is the "Maximum size of extra vocabulary that can be present in a LoRA adapter" per https://github.com/vanbasten23/vllm/blob/7f4a8b6705622fde952a2e633e86716f902d6e1b/vllm/config.py#L3040
-           self.vocab_size += self.lora_config.lora_extra_vocab_size
        self.grammar_bitmask_cpu = np.zeros(
            (self.max_num_reqs, cdiv(self.vocab_size, 32)),
            dtype=np.int32,
@@ -490,9 +522,14 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
        self.rng_params_for_sampling = nnx.Rngs(
            jax.random.key(self.model_config.seed)).params()
-       self.is_multimodal_model = (
-
-
+       self.is_multimodal_model = (
+           self.model_config.is_multimodal_model
+           and self.get_multimodal_embeddings_fn is not None and hasattr(
+               self.model_config.hf_config, "architectures"
+           )  #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
+           and len(self.model_config.hf_config.architectures) >= 1
+           and self.model_config.hf_config.architectures[0]
+           != "Llama4ForConditionalGeneration")
 
        logger.info(f"Init model | "
                    f"hbm={common_utils.hbm_usage_gb(self.devices)}GiB")
@@ -505,6 +542,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        self.kv_cache_config = kv_cache_config
+       self.use_hybrid_kvcache = len(kv_cache_config.kv_cache_groups) > 1
        self.kv_caches = []
        self.kv_cache_manager.initialize_kv_cache(kv_cache_config)
        if has_kv_transfer_group():
@@ -535,16 +573,17 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
        (scheduler_output, attn_metadata, input_ids, hidden_states, logits,
         aux_hidden_states, spec_decode_metadata, kv_connector_output,
-        logits_indices_selector
-
-
-
-
-
-
-
-
-
+        logits_indices_selector,
+        padded_num_reqs) = (self.execute_model_state.scheduler_output,
+                            self.execute_model_state.attn_metadata,
+                            self.execute_model_state.input_ids,
+                            self.execute_model_state.hidden_states,
+                            self.execute_model_state.logits,
+                            self.execute_model_state.aux_hidden_states,
+                            self.execute_model_state.spec_decode_metadata,
+                            self.execute_model_state.kv_connector_output,
+                            self.execute_model_state.logits_indices_selector,
+                            self.execute_model_state.padded_num_reqs)
        self.execute_model_state = None
 
        if grammar_output is not None:
@@ -558,12 +597,10 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                logits,
                arange,
            )
-       return self._sample_from_logits(
-
-
-
-           kv_connector_output,
-           logits_indices_selector)
+       return self._sample_from_logits(
+           scheduler_output, attn_metadata, input_ids, hidden_states, logits,
+           aux_hidden_states, spec_decode_metadata, kv_connector_output,
+           logits_indices_selector, padded_num_reqs)
 
    def _modify_prev_results(self):
        # If copy to host has not been done, we just wait.
@@ -672,13 +709,23 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
        # TODO(pooyam): I guess we can remove returning sampling_metadata in `_prepare_inputs` after https://github.com/njhill/vllm/commit/b7433ca1a47732394b1bdea4099d98389515954b
        (
            input_ids,
+           input_positions,
            attn_metadata,
            _,
            logits_indices,
            spec_decode_metadata,
            logits_indices_selector,
+           padded_num_reqs,
        ) = self._prepare_inputs(scheduler_output)
 
+       is_llama_guard_4 = (
+           hasattr(
+               self.model_config.hf_config, "architectures"
+           )  #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
+           and len(self.model_config.hf_config.architectures) >= 1
+           and self.model_config.hf_config.architectures[0]
+           == "Llama4ForConditionalGeneration")
+
        # multi-modal support
        if self.is_multimodal_model:
            # Run the multimodal encoder if any.
@@ -686,6 +733,13 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
            self.mm_manager.execute_mm_encoder(scheduler_output)
            mm_embeds = self.mm_manager.gather_mm_embeddings(
                scheduler_output, input_ids.shape[0])
+       #TODO: Remove the follow elif statement once Llama Guard 4 Vision portion has been implemented
+       elif is_llama_guard_4 and any(
+               self.mm_manager.runner.requests[req_id].mm_features
+               for req_id in self.mm_manager.runner.input_batch.req_ids):
+           raise NotImplementedError(
+               "Llama Guard 4 (JAX) currently supports only text inputs. "
+               "Multimodal processing not yet implemented.")
        else:
            mm_embeds = []
 
@@ -718,6 +772,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                input_ids,
                attn_metadata,
                inputs_embeds,
+               input_positions,
                tuple(self.layer_name_to_kvcache_index.items()),
                lora_metadata,
            )
@@ -739,7 +794,8 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
            aux_hidden_states=aux_hidden_states,
            spec_decode_metadata=spec_decode_metadata,
            kv_connector_output=kv_connector_output,
-           logits_indices_selector=logits_indices_selector
+           logits_indices_selector=logits_indices_selector,
+           padded_num_reqs=padded_num_reqs)
        return attn_metadata, None
 
    def _sample_from_logits(
@@ -753,11 +809,19 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
        spec_decode_metadata: Optional[SpecDecodeMetadata],
        kv_connector_output: Optional[KVConnectorOutput],
        logits_indices_selector: Optional[List[int]] = None,
+       padded_num_reqs: Optional[int] = None,
    ) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput:
-       padded_num_reqs
-
+       if padded_num_reqs is None:
+           padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
+               self.input_batch.num_reqs, self.max_num_reqs)
+
+       sharding = None
+       if self.dp_size > 1:
+           sharding = NamedSharding(self.mesh,
+                                    PartitionSpec(ShardingAxisName.ATTN_DATA))
+
       tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
-           self.mesh, self.input_batch, padded_num_reqs)
+           self.mesh, self.input_batch, padded_num_reqs, sharding=sharding)
        if spec_decode_metadata is None:
            next_tokens = sample(
                self.rng_params_for_sampling,
@@ -840,7 +904,10 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                logits_indices_selector)
 
            if logprobs is not None:
-
+               # Map logprobs back to the pre-dp shuffling order
+               logprobs_lists = _jax_logprobs_to_lists(
+                   logprobs, logits_indices_selector)
+
            else:
                logprobs_lists = None
 
@@ -908,7 +975,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
            req_state.output_token_ids.extend(sampled_ids)
 
        if logprobs is not None:
-
+           # Map logprobs back to the pre-dp shuffling order
+           logprobs_lists = _jax_logprobs_to_lists(logprobs,
+                                                   logits_indices_selector)
        else:
            logprobs_lists = None
 
@@ -1256,16 +1325,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
        mrope_positions = self.mrope_positions_cpu[:, :
                                                   padded_total_num_scheduled_tokens]
 
-       block_tables = self.block_table_cpu[:self.max_num_reqs]
-       for dp_rank in range(dp_size):
-           req_offset = dp_rank * max_num_reqs_per_dp_rank
-           _num_reqs = num_req_per_dp_rank[dp_rank]
-
-           block_tables[
-               req_offset:req_offset + _num_reqs, :self.
-               max_num_blocks_per_req] = self.input_batch.block_table[
-                   0].get_cpu_tensor()[req_indices_dp[dp_rank]]
-
        query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs +
                                                   dp_size]
        seq_lens = self.seq_lens_cpu[:self.max_num_reqs]
@@ -1307,20 +1366,59 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
        if self.uses_mrope:
            positions = mrope_positions
 
-       # Convert block_tables to 1D on cpu.
-       block_tables = block_tables.reshape(-1)
-
        query_start_loc_cpu = query_start_loc
        logits_indices_cpu = logits_indices
        seq_lens_cpu = seq_lens
 
-       (input_ids, positions,
-
+       (input_ids, positions, query_start_loc, seq_lens, logits_indices,
+        request_distribution) = device_array(
           self.mesh,
-          (input_ids, positions,
-
+          (input_ids, positions, query_start_loc, seq_lens, logits_indices,
+           request_distribution),
           sharding=data_parallel_attn_sharding,
        )
+
+       attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+       uniform_attention_metadata: AttentionMetadata = None
+       for kv_cache_gid, kv_cache_group in enumerate(
+               self.kv_cache_config.kv_cache_groups):
+           block_tables = self.block_tables_cpu[kv_cache_gid][:self.
+                                                              max_num_reqs]
+           for dp_rank in range(dp_size):
+               req_offset = dp_rank * max_num_reqs_per_dp_rank
+               _num_reqs = num_req_per_dp_rank[dp_rank]
+
+               block_tables[
+                   req_offset:req_offset + _num_reqs, :self.
+                   max_num_blocks_per_req] = self.input_batch.block_table[
+                       0].get_cpu_tensor()[req_indices_dp[dp_rank]]
+           # Convert block_tables to 1D on cpu.
+           block_tables = block_tables.reshape(-1)
+           block_tables = device_array(
+               self.mesh,
+               (block_tables),
+               sharding=data_parallel_attn_sharding,
+           )
+
+           attention_metadata_gid = AttentionMetadata(
+               input_positions=positions,
+               block_tables=block_tables,
+               seq_lens=seq_lens,
+               query_start_loc=query_start_loc,
+               request_distribution=request_distribution,
+           )
+
+           # This is for making these cpu buffers hidden during tracing
+           attention_metadata_gid.query_start_loc_cpu = query_start_loc_cpu
+           attention_metadata_gid.seq_lens_cpu = seq_lens_cpu
+
+           if not self.use_hybrid_kvcache:
+               uniform_attention_metadata = attention_metadata_gid
+           else:
+               for layer_name in kv_cache_group.layer_names:
+                   attention_metadata_per_layer[
+                       layer_name] = attention_metadata_gid
+
        # Async scheduling: substitute placeholder tokens for DP
        if self.scheduler_config.async_scheduling and self._pre_async_results is not None:
            # Collect all token indices that need substitution across all DP ranks
@@ -1349,25 +1447,19 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
            padded_total_num_scheduled_tokens,
        )
 
-
-
-
-
-           query_start_loc=query_start_loc,
-           request_distribution=request_distribution,
-       )
-
-       # This is for making these cpu buffers hidden during tracing
-       attention_metadata.query_start_loc_cpu = query_start_loc_cpu
-       attention_metadata.seq_lens_cpu = seq_lens_cpu
-
+       if self.use_hybrid_kvcache:
+           attention_metadata = attention_metadata_per_layer
+       else:
+           attention_metadata = uniform_attention_metadata
        return (
            input_ids,
+           positions,
            attention_metadata,
            sampling_metadata,
            logits_indices,
            spec_decode_metadata,
            logits_indices_selector,
+           padded_num_reqs,
        )
 
    def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"):
@@ -1468,9 +1560,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
        positions = self.positions_cpu[:padded_total_num_scheduled_tokens]
        mrope_positions = self.mrope_positions_cpu[:, :
                                                   padded_total_num_scheduled_tokens]
-       block_tables = self.block_table_cpu[:self.max_num_reqs]
-       block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
-           self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs])
 
        # TODO(pooyam): Some paddings are up to `num_reqs_paddings` (spec decoding, select hidden states, etc) and some other are to `max_num_reqs` (block table, seq_lens). We should stick to one of them maybe?
        query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1]
@@ -1499,16 +1588,44 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
            self.mesh, self.input_batch, padded_num_reqs)
        if self.uses_mrope:
            positions = mrope_positions
-
-       # Convert block_tables to 1D on cpu.
-       block_tables = block_tables.reshape(-1)
-
        query_start_loc_cpu = query_start_loc
        seq_lens_cpu = seq_lens
-
+
+       (input_ids, positions, query_start_loc, seq_lens,
        logits_indices, request_distribution) = device_array(
-           self.mesh, (input_ids, positions,
-
+           self.mesh, (input_ids, positions, query_start_loc, seq_lens,
+                       logits_indices, request_distribution))
+
+       attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+       uniform_attention_metadata: AttentionMetadata = None
+       for kv_cache_gid, kv_cache_group in enumerate(
+               self.kv_cache_config.kv_cache_groups):
+           block_tables = self.block_tables_cpu[kv_cache_gid][:self.
+                                                              max_num_reqs]
+           block_tables[:num_reqs] = (
+               self.input_batch.block_table[kv_cache_gid].get_cpu_tensor()
+               [:num_reqs])
+           # Convert block_tables to 1D on cpu.
+           block_tables = block_tables.reshape(-1)
+           block_tables = device_array(self.mesh, (block_tables))
+
+           attention_metadata_gid = AttentionMetadata(
+               input_positions=positions,
+               block_tables=block_tables,
+               seq_lens=seq_lens,
+               query_start_loc=query_start_loc,
+               request_distribution=request_distribution)
+           # This is for making these cpu buffers hidden during tracing
+           attention_metadata_gid.query_start_loc_cpu = query_start_loc_cpu
+           attention_metadata_gid.seq_lens_cpu = seq_lens_cpu
+
+           if not self.use_hybrid_kvcache:
+               # all layers share the same attention metadata
+               uniform_attention_metadata = attention_metadata_gid
+           else:
+               for layer_name in kv_cache_group.layer_names:
+                   attention_metadata_per_layer[
+                       layer_name] = attention_metadata_gid
 
        if self.scheduler_config.async_scheduling and len(
                token_in_tpu_cur_input_indices) > 0:
@@ -1521,20 +1638,15 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
        self.lora_utils.set_active_loras(
            num_scheduled_tokens_per_req, total_num_scheduled_tokens,
            padded_total_num_scheduled_tokens)
-
-       attention_metadata = AttentionMetadata(
-           input_positions=positions,
-           block_tables=block_tables,
-           seq_lens=seq_lens,
-           query_start_loc=query_start_loc,
-           request_distribution=request_distribution)
-
-       # This is for making these cpu buffers hidden during tracing
-       attention_metadata.query_start_loc_cpu = query_start_loc_cpu
-       attention_metadata.seq_lens_cpu = seq_lens_cpu
        logits_indices_selector = None
-
-
+
+       if self.use_hybrid_kvcache:
+           attention_metadata = attention_metadata_per_layer
+       else:
+           attention_metadata = uniform_attention_metadata
+       return (input_ids, positions, attention_metadata, sampling_metadata,
+               logits_indices, spec_decode_metadata, logits_indices_selector,
+               padded_num_reqs)
 
    def _get_input_ids_embeds(self, input_ids: jax.Array,
                              mm_embeds: list[jax.Array]):
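The `_jax_logprobs_to_lists` helper added above moves the logprob tensors to the host and, when a selector is given, gathers the rows back into the pre-data-parallel-shuffling order. A minimal NumPy-only sketch of that gather, with made-up values and assumed selector semantics (selector[i] is taken to be the row placed at output position i):

    import numpy as np

    # Stand-ins for the device tensors (rows are requests in post-shuffle order).
    logprob_token_ids = np.array([[11, 17], [23, 29], [31, 37]])
    logprobs = np.array([[-0.1, -2.3], [-0.4, -1.9], [-0.2, -3.0]])

    logits_indices_selector = [2, 0, 1]

    ids_rows = logprob_token_ids.tolist()
    lp_rows = logprobs.tolist()
    ids_reordered = [ids_rows[i] for i in logits_indices_selector]
    lps_reordered = [lp_rows[i] for i in logits_indices_selector]
    print(ids_reordered)   # [[31, 37], [11, 17], [23, 29]]
    print(lps_reordered)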
tpu_inference/spec_decode/jax/eagle3.py
CHANGED
@@ -51,7 +51,8 @@ class Eagle3Proposer:
        """Loads the draft model."""
        self.model_fn, self.compute_logits_fn, self.combine_hidden_states_fn, _, self.state, _, _ = get_model(
            self.vllm_config, self.rng_key, self.mesh, is_draft_model=True)
-
+       if 'embed_tokens' in self.state.model:
+           del self.state.model['embed_tokens']
        self.state.model.embed_tokens = target_model.model.embed
 
    @functools.partial(jax.jit, static_argnums=(0, ))
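The eagle3.py change drops the draft model's own embedding table before tying it to the target model's embedding. Illustrated with plain dicts standing in for the parameter state (not the package's actual state objects):

    # Illustration only: free the draft model's unused embedding table,
    # then point the same key at the target model's weights so they are shared.
    target_state = {"model": {"embed": [[0.1, 0.2], [0.3, 0.4]]}}
    draft_state = {"model": {"embed_tokens": [[0.0, 0.0], [0.0, 0.0]]}}

    if "embed_tokens" in draft_state["model"]:
        del draft_state["model"]["embed_tokens"]
    draft_state["model"]["embed_tokens"] = target_state["model"]["embed"]
    print(draft_state["model"]["embed_tokens"] is target_state["model"]["embed"])  # True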