tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference has been flagged as potentially problematic by the registry.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/kernels/mla_v1_test.py +129 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
- tests/lora/test_layers.py +4 -7
- tests/lora/test_lora_perf.py +53 -0
- tests/lora/utils.py +0 -8
- tests/test_envs.py +110 -12
- tests/test_quantization.py +3 -0
- tests/test_utils.py +1 -2
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +3 -4
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +93 -9
- tpu_inference/executors/ray_distributed_executor.py +9 -2
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
- tpu_inference/kernels/mla/v1/kernel.py +98 -120
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +140 -67
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +204 -120
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +11 -7
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +170 -208
- tpu_inference/layers/vllm/linear_common.py +43 -21
- tpu_inference/layers/vllm/quantization/common.py +11 -6
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
- tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
- tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +84 -28
- tpu_inference/models/jax/deepseek_v3.py +185 -64
- tpu_inference/models/jax/gpt_oss.py +3 -3
- tpu_inference/models/jax/llama3.py +2 -1
- tpu_inference/models/jax/llama_eagle3.py +8 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +2 -1
- tpu_inference/models/jax/qwen2_5_vl.py +163 -48
- tpu_inference/models/jax/qwen3.py +2 -1
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
- tpu_inference/models/jax/utils/weight_utils.py +205 -144
- tpu_inference/models/vllm/vllm_model_wrapper.py +14 -8
- tpu_inference/platforms/tpu_platform.py +34 -50
- tpu_inference/runner/compilation_manager.py +144 -60
- tpu_inference/runner/kv_cache.py +40 -20
- tpu_inference/runner/kv_cache_manager.py +48 -33
- tpu_inference/runner/persistent_batch_manager.py +40 -2
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +280 -149
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +71 -21
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +46 -18
- tpu_inference/worker/tpu_worker.py +197 -63
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +9 -10
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +70 -74
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
 import copy
 import functools
-import
+import logging
 import random
 from contextlib import nullcontext
 from dataclasses import dataclass
@@ -10,17 +10,15 @@ import jax
 import jax.numpy as jnp
 import jaxtyping
 import numpy as np
-import
-import vllm.envs as envs
+import vllm.envs as vllm_envs
 from flax import nnx
 from jax.experimental import mesh_utils
 from jax.sharding import NamedSharding, PartitionSpec
-from torchax.ops.mappings import j2t_dtype
 from vllm.config import VllmConfig
+from vllm.distributed import get_pp_group
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.forward_context import set_forward_context
-from vllm.sequence import IntermediateTensors
 from vllm.tasks import SupportedTask
 from vllm.utils.math_utils import cdiv
 from vllm.v1.core.sched.output import GrammarOutput
@@ -35,6 +33,7 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import \
     KVConnectorModelRunnerMixin
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

+import tpu_inference.envs as envs
 from tpu_inference import utils as common_utils
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.common.sharding import (MESH_AXIS_NAMES,
@@ -48,6 +47,8 @@ from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata
 from tpu_inference.logger import init_logger
 from tpu_inference.models.common.model_loader import get_model
+from tpu_inference.models.jax.jax_intermediate_tensor import \
+    JaxIntermediateTensors
 from tpu_inference.models.jax.utils.weight_utils import (
     shard_put, transfer_state_with_mappings)
 from tpu_inference.runner import utils as runner_utils
@@ -64,10 +65,12 @@ from tpu_inference.runner.structured_decoding_manager import \
     StructuredDecodingManager
 from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
 from tpu_inference.utils import (device_array, make_optimized_mesh,
-                                 time_function)
+                                 time_function, to_jax_dtype, to_torch_dtype)

 logger = init_logger(__name__)

+logging.getLogger("torchax.tensor").setLevel(logging.ERROR)
+
 INVALID_TOKEN_ID = -1
 # Smallest output size
 MIN_NUM_SEQS = 8
@@ -78,17 +81,6 @@ DUMMY_METADATA = AttentionMetadata(
     request_distribution=[0, 0, 0],
 )

-TPU_STR_DTYPE_TO_TORCH_DTYPE = {
-    "half": torch.half,
-    "bfloat16": torch.bfloat16,
-    "float": torch.float,
-    "fp8": torch.float8_e4m3fn,
-    "fp8_e4m3": torch.float8_e4m3fn,
-    "fp8_e5m2": torch.float8_e5m2,
-    "int8": torch.int8,
-    "uint8": torch.uint8,
-}
-

 class AsyncTPUModelRunnerOutput(AsyncModelRunnerOutput):
     """Holds asynchronous model output specifically from a TPU runner.
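
Note: the module-level TPU_STR_DTYPE_TO_TORCH_DTYPE table removed above is superseded by the to_torch_dtype helper now imported from tpu_inference.utils (see the import hunk and the kv_cache_dtype hunk further down). The utils.py change itself is not expanded on this page, so the following is only an illustrative sketch of the string-to-torch-dtype path, reusing the mapping from the removed table; the real helper may also accept torch and JAX dtypes.

    # Illustrative sketch only -- not the actual tpu_inference.utils.to_torch_dtype.
    import torch

    _STR_TO_TORCH_DTYPE = {
        "half": torch.half,
        "bfloat16": torch.bfloat16,
        "float": torch.float,
        "fp8": torch.float8_e4m3fn,
        "fp8_e4m3": torch.float8_e4m3fn,
        "fp8_e5m2": torch.float8_e5m2,
        "int8": torch.int8,
        "uint8": torch.uint8,
    }

    def to_torch_dtype_sketch(dtype):
        # Accept either an existing torch.dtype or one of the string aliases above.
        if isinstance(dtype, torch.dtype):
            return dtype
        if isinstance(dtype, str):
            return _STR_TO_TORCH_DTYPE[dtype]
        raise ValueError(f"unsupported dtype spec: {dtype!r}")
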
@@ -153,6 +145,7 @@ class ExecuteModelState:
     spec_decode_metadata: Optional[SpecDecodeMetadata]
     kv_connector_output: Optional[KVConnectorOutput]
     logits_indices_selector: Optional[List[int]] = None
+    padded_num_reqs: Optional[int] = None


 @functools.partial(jax.jit, donate_argnums=(0, 1, 2))
@@ -190,18 +183,28 @@ def _substitute_placeholder_token(
     return input_ids.at[token_in_tpu_cur_input_indices].set(update_values)


-def
+def _jax_logprobs_to_lists(logprobs_tensors,
+                           logits_indices_selector=None,
+                           cu_num_generated_tokens=None):
+    """Convert JAX LogprobsTensors to LogprobsLists by converting JAX arrays to numpy."""
+    log_token_ids_list = logprobs_tensors.logprob_token_ids.tolist()
+    logprobs_list = logprobs_tensors.logprobs.tolist()
+    selected_token_ranks_list = logprobs_tensors.selected_token_ranks.tolist()
+
+    if logits_indices_selector is not None:
+        log_token_ids_list = [
+            log_token_ids_list[i] for i in logits_indices_selector
+        ]
+        logprobs_list = [logprobs_list[i] for i in logits_indices_selector]
+        selected_token_ranks_list = [
+            selected_token_ranks_list[i] for i in logits_indices_selector
+        ]
+
     return LogprobsLists(
-        logprob_token_ids=
-
-
-
-        logprobs=[logprobs_lists.logprobs[i] for i in logits_indices_selector],
-        sampled_token_ranks=[
-            logprobs_lists.sampled_token_ranks[i]
-            for i in logits_indices_selector
-        ],
-        cu_num_generated_tokens=logprobs_lists.cu_num_generated_tokens,
+        logprob_token_ids=np.asarray(log_token_ids_list),
+        logprobs=np.asarray(logprobs_list),
+        sampled_token_ranks=np.asarray(selected_token_ranks_list),
+        cu_num_generated_tokens=cu_num_generated_tokens,
     )

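
The new _jax_logprobs_to_lists above pulls the JAX logprob tensors to Python lists, optionally reorders them back to the pre-data-parallel request order, and repacks them as numpy arrays (it replaces the _reorder_logits_indices call removed later in _sample_from_logits). A self-contained sketch of that behaviour, using a namedtuple stand-in instead of vLLM's LogprobsTensors/LogprobsLists:

    from collections import namedtuple

    import numpy as np

    # Stand-in for vLLM's LogprobsTensors (illustration only).
    FakeLogprobsTensors = namedtuple(
        "FakeLogprobsTensors",
        ["logprob_token_ids", "logprobs", "selected_token_ranks"])

    tensors = FakeLogprobsTensors(
        logprob_token_ids=np.array([[11], [22], [33]]),
        logprobs=np.array([[-0.1], [-0.2], [-0.3]]),
        selected_token_ranks=np.array([0, 1, 0]),
    )
    selector = [2, 0, 1]  # maps shuffled rows back to request order

    token_ids = [tensors.logprob_token_ids.tolist()[i] for i in selector]
    logprobs = [tensors.logprobs.tolist()[i] for i in selector]
    ranks = [tensors.selected_token_ranks.tolist()[i] for i in selector]

    print(np.asarray(token_ids))  # [[33] [11] [22]]
    print(np.asarray(ranks))      # [0 0 1]
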
@@ -211,6 +214,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         self,
         vllm_config: VllmConfig,
         devices: List[Any],
+        rank: int = 0,
+        is_first_rank: bool = True,
+        is_last_rank: bool = True,
     ):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
@@ -229,6 +235,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         self.maybe_forbid_compile = runner_utils.ForbidCompile(
         ) if envs.VLLM_XLA_CHECK_RECOMPILATION else nullcontext()
         self.dp_size = self.vllm_config.sharding_config.total_dp_size
+        self.rank = rank
+        self.is_first_rank = is_first_rank
+        self.is_last_rank = is_last_rank

         self._init_random()
         self._init_mesh()
@@ -239,31 +248,21 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

         # Delegate functions to specific manager classes.
         self.compilation_manager = CompilationManager(self)
-        self.
-
+        if self.is_last_rank:
+            self.speculative_decoding_manager = SpeculativeDecodingManager(
+                self)
+            self.structured_decoding_manager = StructuredDecodingManager(self)
         self.kv_cache_manager = KVCacheManager(self)
         self.mm_manager = MultiModalManager(self)
         self.persistent_batch_manager = PersistentBatchManager(
             self.requests, self.input_batch, self.encoder_cache,
-            self.uses_mrope, self.model_config)
+            self.uses_mrope, self.model_config, self.is_last_rank)
         self.lora_utils = LoraUtils(self)

-
-        if
-
-
-                self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
-            elif isinstance(getattr(model_dtype, 'dtype', None), jnp.dtype):
-                self.kv_cache_dtype = j2t_dtype(model_dtype.dtype)
-            elif isinstance(model_dtype, torch.dtype):
-                self.kv_cache_dtype = model_dtype
-            else:
-                raise ValueError(
-                    "KV cache is unsupported for model_dtype of %s",
-                    model_dtype)
-        else:
-            self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[
-                cache_config.cache_dtype]
+        cache_dtype = self.cache_config.cache_dtype
+        if cache_dtype == "auto":
+            cache_dtype = self.dtype
+        self.kv_cache_dtype = to_torch_dtype(cache_dtype)

         self._pre_async_results: AsyncPreResults | None = None
         self._substitute_placeholder_token_fn = _substitute_placeholder_token
@@ -277,7 +276,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         self.rng_key = jax.random.key(self.model_config.seed)

     def _init_mesh(self) -> None:
-        if
+        if envs.NEW_MODEL_DESIGN:
             self.mesh = self._create_new_model_mesh()
         else:
             # NOTE(wenxindongwork): The new MoE kernel expects a 2D mesh, so we need
@@ -288,7 +287,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         logger.info(f"Init mesh | mesh={self.mesh}")

     def _create_new_model_mesh(self) -> jax.sharding.Mesh:
-        num_slices =
+        num_slices = envs.NUM_SLICES

         logger.info(f"Creating new model mesh | devices={len(self.devices)}, "
                     f"num_slices={num_slices}")
@@ -357,7 +356,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             devices=self.devices)

     def _init_phased_profiling(self) -> None:
-        self.phased_profiling_dir =
+        self.phased_profiling_dir = envs.PHASED_PROFILING_DIR
         self.phase_based_profiler = None
         if self.phased_profiling_dir:
             self.phase_based_profiler = runner_utils.PhasedBasedProfiler(
@@ -399,7 +398,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             min_token_size=max(16, self.dp_size),
             max_token_size=scheduler_config.max_num_batched_tokens *
             self.dp_size,
-            padding_gap=
+            padding_gap=vllm_envs.VLLM_TPU_BUCKET_PADDING_GAP)
         self.num_tokens_paddings_per_dp = [
             padding // self.dp_size for padding in self.num_tokens_paddings
         ]
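
Several lookups in this region now go through tpu_inference.envs (imported as envs) or vllm.envs (imported as vllm_envs), e.g. envs.NEW_MODEL_DESIGN, envs.NUM_SLICES, envs.PHASED_PROFILING_DIR and vllm_envs.VLLM_TPU_BUCKET_PADDING_GAP. The envs.py diff itself is not expanded on this page; the snippet below is only a rough sketch of the usual vLLM-style pattern such a module follows (module attributes lazily resolved from environment variables via PEP 562), with made-up defaults.

    # envs_sketch.py -- hypothetical, not the real tpu_inference/envs.py.
    import os

    _DEFINITIONS = {
        # attribute name: (environment variable, parser, default)
        "NEW_MODEL_DESIGN": ("NEW_MODEL_DESIGN", lambda v: v in ("1", "true", "True"), False),
        "NUM_SLICES": ("NUM_SLICES", int, 1),
        "PHASED_PROFILING_DIR": ("PHASED_PROFILING_DIR", str, ""),
    }

    def __getattr__(name):
        # Invoked for module-level attribute access, e.g. envs_sketch.NUM_SLICES.
        try:
            var, parse, default = _DEFINITIONS[name]
        except KeyError:
            raise AttributeError(name) from None
        raw = os.environ.get(var)
        return default if raw is None else parse(raw)
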
@@ -423,8 +422,14 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

         self.input_ids_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
         self.positions_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
-        self.
-
+        # Note: self.input_batch and self.block_tables_cpu are both initialized
+        # with only 1 block_size. For hybrid kv cache, it will be re-init
+        # in kv_cache_manager's maybe_reinitialize_input_batch.
+        self.block_tables_cpu = [
+            np.zeros((self.max_num_reqs, self.max_num_blocks_per_req),
+                     dtype=np.int32)
+        ]
+
         self.query_start_loc_cpu = np.zeros(self.max_num_reqs + self.dp_size,
                                             dtype=np.int32)
         self.seq_lens_cpu = np.zeros(self.max_num_reqs, dtype=np.int32)
@@ -458,9 +463,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

         # tensors for structured decoding
         self.vocab_size = self.model_config.get_vocab_size()
-        if self.lora_config is not None:
-            # lora_config.lora_extra_vocab_size is the "Maximum size of extra vocabulary that can be present in a LoRA adapter" per https://github.com/vanbasten23/vllm/blob/7f4a8b6705622fde952a2e633e86716f902d6e1b/vllm/config.py#L3040
-            self.vocab_size += self.lora_config.lora_extra_vocab_size
         self.grammar_bitmask_cpu = np.zeros(
             (self.max_num_reqs, cdiv(self.vocab_size, 32)),
             dtype=np.int32,
@@ -505,9 +507,14 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

         self.rng_params_for_sampling = nnx.Rngs(
             jax.random.key(self.model_config.seed)).params()
-        self.is_multimodal_model = (
-
-
+        self.is_multimodal_model = (
+            self.model_config.is_multimodal_model
+            and self.get_multimodal_embeddings_fn is not None and hasattr(
+                self.model_config.hf_config, "architectures"
+            ) #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
+            and len(self.model_config.hf_config.architectures) >= 1
+            and self.model_config.hf_config.architectures[0]
+            != "Llama4ForConditionalGeneration")

         logger.info(f"Init model | "
                     f"hbm={common_utils.hbm_usage_gb(self.devices)}GiB")
@@ -520,6 +527,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         self.kv_cache_config = kv_cache_config
+        self.use_hybrid_kvcache = len(kv_cache_config.kv_cache_groups) > 1
         self.kv_caches = []
         self.kv_cache_manager.initialize_kv_cache(kv_cache_config)
         if has_kv_transfer_group():
@@ -532,12 +540,12 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
     def execute_model(
         self,
         scheduler_output: "VllmSchedulerOutput",
-        intermediate_tensors: Optional[
-    ) -> ModelRunnerOutput | None:
+        intermediate_tensors: Optional[JaxIntermediateTensors] = None,
+    ) -> ModelRunnerOutput | JaxIntermediateTensors | None:
         if self.execute_model_state is not None:
             raise RuntimeError("State error: sample_tokens() must be called "
                                "after execute_model() returns None.")
-        _, output = self._execute_model(scheduler_output)
+        _, output = self._execute_model(scheduler_output, intermediate_tensors)
         return output

     def sample_tokens(
@@ -550,16 +558,17 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

         (scheduler_output, attn_metadata, input_ids, hidden_states, logits,
          aux_hidden_states, spec_decode_metadata, kv_connector_output,
-         logits_indices_selector
-
-
-
-
-
-
-
-
-
+         logits_indices_selector,
+         padded_num_reqs) = (self.execute_model_state.scheduler_output,
+                             self.execute_model_state.attn_metadata,
+                             self.execute_model_state.input_ids,
+                             self.execute_model_state.hidden_states,
+                             self.execute_model_state.logits,
+                             self.execute_model_state.aux_hidden_states,
+                             self.execute_model_state.spec_decode_metadata,
+                             self.execute_model_state.kv_connector_output,
+                             self.execute_model_state.logits_indices_selector,
+                             self.execute_model_state.padded_num_reqs)
         self.execute_model_state = None

         if grammar_output is not None:
@@ -573,12 +582,10 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             logits,
             arange,
         )
-        return self._sample_from_logits(
-
-
-
-            kv_connector_output,
-            logits_indices_selector)
+        return self._sample_from_logits(
+            scheduler_output, attn_metadata, input_ids, hidden_states, logits,
+            aux_hidden_states, spec_decode_metadata, kv_connector_output,
+            logits_indices_selector, padded_num_reqs)

     def _modify_prev_results(self):
         # If copy to host has not been done, we just wait.
@@ -664,7 +671,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
     def _execute_model(
         self,
         scheduler_output: "VllmSchedulerOutput",
-
+        intermediate_tensors: Optional[JaxIntermediateTensors] = None,
+    ) -> tuple[AttentionMetadata, JaxIntermediateTensors | ModelRunnerOutput
+               | None]:
         self.persistent_batch_manager.update_states(
             scheduler_output, self.get_mrope_input_positions_fn)
         if not scheduler_output.total_num_scheduled_tokens:
@@ -687,13 +696,23 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         # TODO(pooyam): I guess we can remove returning sampling_metadata in `_prepare_inputs` after https://github.com/njhill/vllm/commit/b7433ca1a47732394b1bdea4099d98389515954b
         (
             input_ids,
+            input_positions,
            attn_metadata,
             _,
             logits_indices,
             spec_decode_metadata,
             logits_indices_selector,
+            padded_num_reqs,
         ) = self._prepare_inputs(scheduler_output)

+        is_llama_guard_4 = (
+            hasattr(
+                self.model_config.hf_config, "architectures"
+            ) #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
+            and len(self.model_config.hf_config.architectures) >= 1
+            and self.model_config.hf_config.architectures[0]
+            == "Llama4ForConditionalGeneration")
+
         # multi-modal support
         if self.is_multimodal_model:
             # Run the multimodal encoder if any.
@@ -701,6 +720,13 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             self.mm_manager.execute_mm_encoder(scheduler_output)
             mm_embeds = self.mm_manager.gather_mm_embeddings(
                 scheduler_output, input_ids.shape[0])
+        #TODO: Remove the follow elif statement once Llama Guard 4 Vision portion has been implemented
+        elif is_llama_guard_4 and any(
+                self.mm_manager.runner.requests[req_id].mm_features
+                for req_id in self.mm_manager.runner.input_batch.req_ids):
+            raise NotImplementedError(
+                "Llama Guard 4 (JAX) currently supports only text inputs. "
+                "Multimodal processing not yet implemented.")
         else:
             mm_embeds = []

@@ -725,7 +751,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 scheduler_output) as kv_connector_output:
             # NOTE(Wenlong): It takes both `input_ids` and `inputs_embeds`,
             # but one of them would be `None`
-
             (self.kv_caches, hidden_states,
              aux_hidden_states) = self.model_fn(
                  self.state,
@@ -733,10 +758,17 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                  input_ids,
                  attn_metadata,
                  inputs_embeds,
+                 input_positions,
                  tuple(self.layer_name_to_kvcache_index.items()),
                  lora_metadata,
+                 intermediate_tensors,
+                 self.is_first_rank,
+                 self.is_last_rank,
              )
-
+            if not get_pp_group().is_last_rank:
+                assert isinstance(hidden_states, JaxIntermediateTensors)
+                hidden_states.kv_connector_output = kv_connector_output
+                return attn_metadata, hidden_states
             hidden_states = self._select_from_array_fn(hidden_states,
                                                        logits_indices)
             logits = self.compute_logits_fn(
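
With the pipeline-parallel plumbing above, a rank that is not the last pipeline stage gets a JaxIntermediateTensors back from model_fn, attaches the kv_connector_output to it, and returns it instead of computing logits; only the last rank continues to logits and sampling. A rough, self-contained sketch of that handoff using a stand-in class (the real JaxIntermediateTensors lives in tpu_inference/models/jax/jax_intermediate_tensor.py and is not expanded on this page):

    from dataclasses import dataclass
    from typing import Any, Dict

    import jax.numpy as jnp

    @dataclass
    class IntermediateTensorsStandIn:
        tensors: Dict[str, jnp.ndarray]
        kv_connector_output: Any = None

    def run_rank(hidden, is_last_rank):
        hidden = hidden + 1.0  # stand-in for the transformer layers owned by this rank
        if not is_last_rank:
            # earlier ranks hand hidden states (and residual) to the next rank
            return IntermediateTensorsStandIn({"hidden_states": hidden,
                                               "residual": hidden})
        return hidden.sum()  # stand-in for logits/sampling on the last rank

    out0 = run_rank(jnp.zeros((4, 8)), is_last_rank=False)
    out1 = run_rank(out0.tensors["hidden_states"], is_last_rank=True)
    print(out1)  # 64.0
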
@@ -754,7 +786,8 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             aux_hidden_states=aux_hidden_states,
             spec_decode_metadata=spec_decode_metadata,
             kv_connector_output=kv_connector_output,
-            logits_indices_selector=logits_indices_selector
+            logits_indices_selector=logits_indices_selector,
+            padded_num_reqs=padded_num_reqs)
         return attn_metadata, None

     def _sample_from_logits(
@@ -768,23 +801,44 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         spec_decode_metadata: Optional[SpecDecodeMetadata],
         kv_connector_output: Optional[KVConnectorOutput],
         logits_indices_selector: Optional[List[int]] = None,
+        padded_num_reqs: Optional[int] = None,
     ) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput:
-        padded_num_reqs
-
+        if padded_num_reqs is None:
+            padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
+                self.input_batch.num_reqs, self.max_num_reqs)
+
+        sharding = None
+        if self.dp_size > 1:
+            sharding = NamedSharding(self.mesh,
+                                     PartitionSpec(ShardingAxisName.ATTN_DATA))
+
         tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
-            self.mesh, self.input_batch, padded_num_reqs)
+            self.mesh, self.input_batch, padded_num_reqs, sharding=sharding)
+
+        # TODO(pooyam): Should we move this to `_prepare_inputs`?
+        if tpu_sampling_metadata.do_sampling:
+            self.rng_params_for_sampling, step_rng = jax.random.split(
+                self.rng_params_for_sampling)
+        else:
+            step_rng = self.rng_params_for_sampling
+
         if spec_decode_metadata is None:
             next_tokens = sample(
-
+                step_rng,
                 self.mesh,
                 logits,
                 tpu_sampling_metadata,
             )
         else:
+            if tpu_sampling_metadata.do_sampling:
+                bonus_rng, rejection_rng = jax.random.split(step_rng)
+            else:
+                bonus_rng = step_rng
+                rejection_rng = step_rng
             bonus_logits = self._select_from_array_fn(
                 logits, spec_decode_metadata.bonus_logits_indices)
             bonus_token_ids = sample(
-
+                bonus_rng,
                 self.mesh,
                 bonus_logits,
                 tpu_sampling_metadata,
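
Sampling now threads an explicit PRNG key: the persistent key is split once per step when any request actually samples, and split again into separate bonus / rejection keys on the speculative-decoding path, so no key is reused. A standalone illustration of that splitting pattern with plain JAX keys:

    import jax

    state_key = jax.random.key(0)  # persistent key, akin to rng_params_for_sampling

    # Per step: advance the persistent key and obtain a fresh step key.
    state_key, step_key = jax.random.split(state_key)

    # Speculative decoding: independent keys for bonus-token and rejection sampling.
    bonus_key, rejection_key = jax.random.split(step_key)

    print(jax.random.uniform(bonus_key))
    print(jax.random.uniform(rejection_key))
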
@@ -798,7 +852,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 target_logits=target_logits,
                 bonus_token_ids=bonus_token_ids,
                 sampling_metadata=tpu_sampling_metadata,
-                key=
+                key=rejection_rng,
             )

         if tpu_sampling_metadata.logprobs:
@@ -856,10 +910,8 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

         if logprobs is not None:
             # Map logprobs back to the pre-dp shuffling order
-            logprobs_lists =
-
-            logprobs_lists = _reorder_logits_indices(
-                logprobs_lists, logits_indices_selector)
+            logprobs_lists = _jax_logprobs_to_lists(
+                logprobs, logits_indices_selector)

         else:
             logprobs_lists = None
@@ -929,10 +981,8 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

         if logprobs is not None:
             # Map logprobs back to the pre-dp shuffling order
-            logprobs_lists = logprobs
-
-            logprobs_lists = _reorder_logits_indices(
-                logprobs_lists, logits_indices_selector)
+            logprobs_lists = _jax_logprobs_to_lists(logprobs,
+                                                    logits_indices_selector)

         else:
             logprobs_lists = None
@@ -1280,16 +1330,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         mrope_positions = self.mrope_positions_cpu[:, :
                                                    padded_total_num_scheduled_tokens]

-        block_tables = self.block_table_cpu[:self.max_num_reqs]
-        for dp_rank in range(dp_size):
-            req_offset = dp_rank * max_num_reqs_per_dp_rank
-            _num_reqs = num_req_per_dp_rank[dp_rank]
-
-            block_tables[
-                req_offset:req_offset + _num_reqs, :self.
-                max_num_blocks_per_req] = self.input_batch.block_table[
-                    0].get_cpu_tensor()[req_indices_dp[dp_rank]]
-
         query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs +
                                                    dp_size]
         seq_lens = self.seq_lens_cpu[:self.max_num_reqs]
@@ -1297,7 +1337,14 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         _request_distribution = []
         for dp_rank in range(dp_size):
             _num_reqs = num_req_per_dp_rank[dp_rank]
-
+            # The batch has been reordered by _reorder_batch so decode requests come first
+            # Count decode requests (those with num_scheduled_tokens == 1) in this DP rank
+            num_decode_in_dp_rank = 0
+            for req_id in req_ids_dp[dp_rank]:
+                if scheduler_output.num_scheduled_tokens[req_id] == 1:
+                    num_decode_in_dp_rank += 1
+            _request_distribution.append(
+                [num_decode_in_dp_rank, num_decode_in_dp_rank, _num_reqs])
         request_distribution = np.array(_request_distribution).ravel()

         use_spec_decode = len(
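
The request_distribution entries built above are [num_decodes, num_decodes, num_reqs] per data-parallel rank, where a request scheduled for exactly one token counts as a decode (the batch was already reordered so decodes come first). A toy recomputation with made-up request ids:

    import numpy as np

    num_scheduled_tokens = {"a": 1, "b": 7, "c": 1, "d": 1}  # req_id -> tokens this step
    req_ids_per_rank = [["a", "b"], ["c", "d"]]              # two DP ranks

    distribution = []
    for req_ids in req_ids_per_rank:
        num_decode = sum(1 for r in req_ids if num_scheduled_tokens[r] == 1)
        distribution.append([num_decode, num_decode, len(req_ids)])

    print(np.array(distribution).ravel())  # [1 1 2 2 2 2]
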
@@ -1331,20 +1378,59 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         if self.uses_mrope:
             positions = mrope_positions

-        # Convert block_tables to 1D on cpu.
-        block_tables = block_tables.reshape(-1)
-
         query_start_loc_cpu = query_start_loc
         logits_indices_cpu = logits_indices
         seq_lens_cpu = seq_lens

-        (input_ids, positions,
-
+        (input_ids, positions, query_start_loc, seq_lens, logits_indices,
+         request_distribution) = device_array(
             self.mesh,
-            (input_ids, positions,
-
+            (input_ids, positions, query_start_loc, seq_lens, logits_indices,
+             request_distribution),
             sharding=data_parallel_attn_sharding,
         )
+
+        attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+        uniform_attention_metadata: AttentionMetadata = None
+        for kv_cache_gid, kv_cache_group in enumerate(
+                self.kv_cache_config.kv_cache_groups):
+            block_tables = self.block_tables_cpu[kv_cache_gid][:self.
+                                                               max_num_reqs]
+            for dp_rank in range(dp_size):
+                req_offset = dp_rank * max_num_reqs_per_dp_rank
+                _num_reqs = num_req_per_dp_rank[dp_rank]
+
+                block_tables[
+                    req_offset:req_offset + _num_reqs, :self.
+                    max_num_blocks_per_req] = self.input_batch.block_table[
+                        kv_cache_gid].get_cpu_tensor()[req_indices_dp[dp_rank]]
+            # Convert block_tables to 1D on cpu.
+            block_tables = block_tables.reshape(-1)
+            block_tables = device_array(
+                self.mesh,
+                (block_tables),
+                sharding=data_parallel_attn_sharding,
+            )
+
+            attention_metadata_gid = AttentionMetadata(
+                input_positions=positions,
+                block_tables=block_tables,
+                seq_lens=seq_lens,
+                query_start_loc=query_start_loc,
+                request_distribution=request_distribution,
+            )
+
+            # This is for making these cpu buffers hidden during tracing
+            attention_metadata_gid.query_start_loc_cpu = query_start_loc_cpu
+            attention_metadata_gid.seq_lens_cpu = seq_lens_cpu
+
+            if not self.use_hybrid_kvcache:
+                uniform_attention_metadata = attention_metadata_gid
+            else:
+                for layer_name in kv_cache_group.layer_names:
+                    attention_metadata_per_layer[
+                        layer_name] = attention_metadata_gid
+
         # Async scheduling: substitute placeholder tokens for DP
         if self.scheduler_config.async_scheduling and self._pre_async_results is not None:
             # Collect all token indices that need substitution across all DP ranks
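
The hunk above is the data-parallel half of the hybrid-KV-cache support: block tables and AttentionMetadata are now built per KV cache group, and when there is more than one group the runner hands the model a {layer_name: metadata} dict instead of one shared object (the non-DP path below gets the same treatment). A stripped-down sketch of that fan-out, with made-up group and layer names and a plain dict standing in for AttentionMetadata:

    groups = {
        "full_attention": ["layers.0", "layers.2"],
        "sliding_window": ["layers.1", "layers.3"],
    }

    use_hybrid_kvcache = len(groups) > 1
    per_layer_metadata = {}
    uniform_metadata = None

    for group_name, layer_names in groups.items():
        metadata = {"group": group_name}  # stand-in for AttentionMetadata
        if not use_hybrid_kvcache:
            uniform_metadata = metadata   # single group: every layer shares it
        else:
            for layer_name in layer_names:
                per_layer_metadata[layer_name] = metadata

    attention_metadata = per_layer_metadata if use_hybrid_kvcache else uniform_metadata
    print(attention_metadata["layers.3"])  # {'group': 'sliding_window'}
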
@@ -1373,25 +1459,19 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 padded_total_num_scheduled_tokens,
             )

-
-
-
-
-            query_start_loc=query_start_loc,
-            request_distribution=request_distribution,
-        )
-
-        # This is for making these cpu buffers hidden during tracing
-        attention_metadata.query_start_loc_cpu = query_start_loc_cpu
-        attention_metadata.seq_lens_cpu = seq_lens_cpu
-
+        if self.use_hybrid_kvcache:
+            attention_metadata = attention_metadata_per_layer
+        else:
+            attention_metadata = uniform_attention_metadata
         return (
             input_ids,
+            positions,
             attention_metadata,
             sampling_metadata,
             logits_indices,
             spec_decode_metadata,
             logits_indices_selector,
+            padded_num_reqs,
         )

     def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"):
@@ -1492,9 +1572,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         positions = self.positions_cpu[:padded_total_num_scheduled_tokens]
         mrope_positions = self.mrope_positions_cpu[:, :
                                                    padded_total_num_scheduled_tokens]
-        block_tables = self.block_table_cpu[:self.max_num_reqs]
-        block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
-            self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs])

         # TODO(pooyam): Some paddings are up to `num_reqs_paddings` (spec decoding, select hidden states, etc) and some other are to `max_num_reqs` (block table, seq_lens). We should stick to one of them maybe?
         query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1]
@@ -1523,16 +1600,44 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             self.mesh, self.input_batch, padded_num_reqs)
         if self.uses_mrope:
             positions = mrope_positions
-
-        # Convert block_tables to 1D on cpu.
-        block_tables = block_tables.reshape(-1)
-
         query_start_loc_cpu = query_start_loc
         seq_lens_cpu = seq_lens
-
+
+        (input_ids, positions, query_start_loc, seq_lens,
         logits_indices, request_distribution) = device_array(
-            self.mesh, (input_ids, positions,
-
+            self.mesh, (input_ids, positions, query_start_loc, seq_lens,
+                        logits_indices, request_distribution))
+
+        attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+        uniform_attention_metadata: AttentionMetadata = None
+        for kv_cache_gid, kv_cache_group in enumerate(
+                self.kv_cache_config.kv_cache_groups):
+            block_tables = self.block_tables_cpu[kv_cache_gid][:self.
+                                                               max_num_reqs]
+            block_tables[:num_reqs] = (
+                self.input_batch.block_table[kv_cache_gid].get_cpu_tensor()
+                [:num_reqs])
+            # Convert block_tables to 1D on cpu.
+            block_tables = block_tables.reshape(-1)
+            block_tables = device_array(self.mesh, (block_tables))
+
+            attention_metadata_gid = AttentionMetadata(
+                input_positions=positions,
+                block_tables=block_tables,
+                seq_lens=seq_lens,
+                query_start_loc=query_start_loc,
+                request_distribution=request_distribution)
+            # This is for making these cpu buffers hidden during tracing
+            attention_metadata_gid.query_start_loc_cpu = query_start_loc_cpu
+            attention_metadata_gid.seq_lens_cpu = seq_lens_cpu
+
+            if not self.use_hybrid_kvcache:
+                # all layers share the same attention metadata
+                uniform_attention_metadata = attention_metadata_gid
+            else:
+                for layer_name in kv_cache_group.layer_names:
+                    attention_metadata_per_layer[
+                        layer_name] = attention_metadata_gid

         if self.scheduler_config.async_scheduling and len(
                 token_in_tpu_cur_input_indices) > 0:
@@ -1545,20 +1650,15 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         self.lora_utils.set_active_loras(
             num_scheduled_tokens_per_req, total_num_scheduled_tokens,
             padded_total_num_scheduled_tokens)
-
-        attention_metadata = AttentionMetadata(
-            input_positions=positions,
-            block_tables=block_tables,
-            seq_lens=seq_lens,
-            query_start_loc=query_start_loc,
-            request_distribution=request_distribution)
-
-        # This is for making these cpu buffers hidden during tracing
-        attention_metadata.query_start_loc_cpu = query_start_loc_cpu
-        attention_metadata.seq_lens_cpu = seq_lens_cpu
         logits_indices_selector = None
-
-
+
+        if self.use_hybrid_kvcache:
+            attention_metadata = attention_metadata_per_layer
+        else:
+            attention_metadata = uniform_attention_metadata
+        return (input_ids, positions, attention_metadata, sampling_metadata,
+                logits_indices, spec_decode_metadata, logits_indices_selector,
+                padded_num_reqs)

     def _get_input_ids_embeds(self, input_ids: jax.Array,
                               mm_embeds: list[jax.Array]):
@@ -1618,3 +1718,34 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             mappings=mappings,
             transpose_keys=transpose_keys,
             shard=shard)
+
+    def get_intermediate_tensor_spec(self, num_tokens: int):
+        jax_dtype = to_jax_dtype(self.dtype)
+        num_padded_tokens = runner_utils.get_padded_token_len(
+            self.num_tokens_paddings, num_tokens)
+        sharding = NamedSharding(self.mesh, PartitionSpec())
+        hidden_size = self.model_config.get_hidden_size()
+        spec = jax.ShapeDtypeStruct(shape=(num_padded_tokens, hidden_size),
+                                    dtype=jax_dtype,
+                                    sharding=sharding)
+        tensor_spec = {"hidden_states": spec, "residual": spec}
+        return tensor_spec
+
+    def get_uuid_for_jax_transfer(self,
+                                  scheduler_output: "VllmSchedulerOutput",
+                                  rank: int, step: int) -> int:
+        '''
+        Get a uuid for jax.transfer, here we use the hash of
+        scheduler_output + counter_step + sender's rank
+        '''
+        scheduler_output_str = ""
+        if not scheduler_output.num_scheduled_tokens:
+            scheduler_output_str = "empty_batch"
+        else:
+            scheduler_output_str = str(
+                sorted(scheduler_output.num_scheduled_tokens.items()))
+        unique_str = f'{scheduler_output_str} {step} {rank}'
+        import hashlib
+        hasher = hashlib.sha1()
+        hasher.update(unique_str.encode('utf-8'))
+        return int.from_bytes(hasher.digest()[:8], 'big')
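
get_uuid_for_jax_transfer gives both ends of a jax transfer the same 64-bit identifier without exchanging it: hash the sorted num_scheduled_tokens mapping together with the step counter and the sender's rank, and take the first 8 bytes of the SHA-1 digest. A standalone replay of that derivation:

    import hashlib

    def transfer_uuid(num_scheduled_tokens, rank, step):
        payload = ("empty_batch" if not num_scheduled_tokens else
                   str(sorted(num_scheduled_tokens.items())))
        digest = hashlib.sha1(f"{payload} {step} {rank}".encode("utf-8")).digest()
        return int.from_bytes(digest[:8], "big")

    a = transfer_uuid({"req-1": 16, "req-0": 1}, rank=0, step=42)
    b = transfer_uuid({"req-0": 1, "req-1": 16}, rank=0, step=42)
    assert a == b  # dict insertion order is irrelevant because items are sorted
    print(hex(a))
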
|