tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511130813__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/kernels/fused_moe_v1_test.py +34 -303
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
- tests/lora/test_layers.py +6 -0
- tests/lora/utils.py +8 -0
- tests/test_utils.py +16 -24
- tpu_inference/__init__.py +3 -22
- tpu_inference/core/core_tpu.py +9 -17
- tpu_inference/core/disagg_utils.py +8 -6
- tpu_inference/distributed/tpu_connector.py +4 -3
- tpu_inference/distributed/utils.py +2 -3
- tpu_inference/envs.py +8 -61
- tpu_inference/executors/ray_distributed_executor.py +11 -31
- tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +143 -287
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +0 -7
- tpu_inference/layers/jax/attention/attention.py +1 -1
- tpu_inference/layers/{common → jax}/attention_interface.py +2 -8
- tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
- tpu_inference/layers/jax/sample/sampling.py +2 -2
- tpu_inference/layers/{common → jax}/sharding.py +5 -5
- tpu_inference/layers/vllm/attention.py +1 -1
- tpu_inference/layers/vllm/fused_moe.py +208 -170
- tpu_inference/layers/vllm/quantization/__init__.py +3 -7
- tpu_inference/layers/vllm/quantization/awq.py +3 -4
- tpu_inference/layers/vllm/quantization/common.py +1 -6
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +2 -4
- tpu_inference/layers/vllm/quantization/unquantized.py +67 -62
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +2 -1
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/common/model_loader.py +12 -46
- tpu_inference/models/jax/llama3.py +3 -4
- tpu_inference/models/jax/llama_eagle3.py +5 -8
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +2 -3
- tpu_inference/models/jax/qwen2_5_vl.py +50 -165
- tpu_inference/models/jax/qwen3.py +2 -3
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
- tpu_inference/models/jax/utils/weight_utils.py +143 -198
- tpu_inference/models/vllm/vllm_model_wrapper.py +14 -32
- tpu_inference/platforms/tpu_platform.py +34 -47
- tpu_inference/runner/compilation_manager.py +60 -145
- tpu_inference/runner/kv_cache.py +2 -2
- tpu_inference/runner/kv_cache_manager.py +18 -17
- tpu_inference/runner/persistent_batch_manager.py +2 -40
- tpu_inference/runner/structured_decoding_manager.py +3 -2
- tpu_inference/runner/tpu_runner.py +135 -283
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +21 -71
- tpu_inference/tpu_info.py +3 -4
- tpu_inference/utils.py +15 -38
- tpu_inference/worker/tpu_worker.py +26 -163
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/METADATA +3 -4
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/RECORD +63 -61
- tests/test_envs.py +0 -203
- tpu_inference/layers/common/quant_methods.py +0 -8
- tpu_inference/layers/vllm/quantization/mxfp4.py +0 -331
- tpu_inference/models/jax/llama_guard_4.py +0 -361
- /tpu_inference/layers/{common → jax}/binary_search.py +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/WHEEL +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/top_level.txt +0 -0
@@ -10,23 +10,24 @@ import jax
 import jax.numpy as jnp
 import jaxtyping
 import numpy as np
-import
+import torch
+import vllm.envs as envs
 from flax import nnx
 from jax.experimental import mesh_utils
 from jax.sharding import NamedSharding, PartitionSpec
-from torchax.ops.mappings import
+from torchax.ops.mappings import j2t_dtype
 from vllm.config import VllmConfig
-from vllm.distributed import get_pp_group
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.forward_context import set_forward_context
+from vllm.sequence import IntermediateTensors
 from vllm.tasks import SupportedTask
 from vllm.utils.math_utils import cdiv
 from vllm.v1.core.sched.output import GrammarOutput
 from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
-                             DraftTokenIds, KVConnectorOutput,
+                             DraftTokenIds, KVConnectorOutput,
                              ModelRunnerOutput)
 from vllm.v1.request import Request
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
@@ -34,22 +35,19 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import \
     KVConnectorModelRunnerMixin
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 
-import tpu_inference.envs as envs
 from tpu_inference import utils as common_utils
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
-from tpu_inference.layers.common.sharding import (MESH_AXIS_NAMES,
-                                                   MESH_AXIS_NAMES_2D,
-                                                   ShardingAxisName,
-                                                   ShardingConfigManager)
 from tpu_inference.layers.jax.sample.rejection_sampler import RejectionSampler
 from tpu_inference.layers.jax.sample.sampling import (compute_logprobs,
                                                       gather_logprobs, sample)
 from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata
+from tpu_inference.layers.jax.sharding import (MESH_AXIS_NAMES,
+                                               MESH_AXIS_NAMES_2D,
+                                               ShardingAxisName,
+                                               ShardingConfigManager)
 from tpu_inference.logger import init_logger
 from tpu_inference.models.common.model_loader import get_model
-from tpu_inference.models.jax.jax_intermediate_tensor import \
-    JaxIntermediateTensors
 from tpu_inference.models.jax.utils.weight_utils import (
     shard_put, transfer_state_with_mappings)
 from tpu_inference.runner import utils as runner_utils
@@ -66,7 +64,7 @@ from tpu_inference.runner.structured_decoding_manager import \
     StructuredDecodingManager
 from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
 from tpu_inference.utils import (device_array, make_optimized_mesh,
-                                 time_function
+                                 time_function)
 
 logger = init_logger(__name__)
 
@@ -80,6 +78,17 @@ DUMMY_METADATA = AttentionMetadata(
     request_distribution=[0, 0, 0],
 )
 
+TPU_STR_DTYPE_TO_TORCH_DTYPE = {
+    "half": torch.half,
+    "bfloat16": torch.bfloat16,
+    "float": torch.float,
+    "fp8": torch.float8_e4m3fn,
+    "fp8_e4m3": torch.float8_e4m3fn,
+    "fp8_e5m2": torch.float8_e5m2,
+    "int8": torch.int8,
+    "uint8": torch.uint8,
+}
+
 
 class AsyncTPUModelRunnerOutput(AsyncModelRunnerOutput):
     """Holds asynchronous model output specifically from a TPU runner.
@@ -144,7 +153,6 @@ class ExecuteModelState:
     spec_decode_metadata: Optional[SpecDecodeMetadata]
     kv_connector_output: Optional[KVConnectorOutput]
     logits_indices_selector: Optional[List[int]] = None
-    padded_num_reqs: Optional[int] = None
 
 
 @functools.partial(jax.jit, donate_argnums=(0, 1, 2))
@@ -182,40 +190,12 @@ def _substitute_placeholder_token(
     return input_ids.at[token_in_tpu_cur_input_indices].set(update_values)
 
 
-def _jax_logprobs_to_lists(logprobs_tensors,
-                           logits_indices_selector=None,
-                           cu_num_generated_tokens=None):
-    """Convert JAX LogprobsTensors to LogprobsLists by converting JAX arrays to numpy."""
-    log_token_ids_list = logprobs_tensors.logprob_token_ids.tolist()
-    logprobs_list = logprobs_tensors.logprobs.tolist()
-    selected_token_ranks_list = logprobs_tensors.selected_token_ranks.tolist()
-
-    if logits_indices_selector is not None:
-        log_token_ids_list = [
-            log_token_ids_list[i] for i in logits_indices_selector
-        ]
-        logprobs_list = [logprobs_list[i] for i in logits_indices_selector]
-        selected_token_ranks_list = [
-            selected_token_ranks_list[i] for i in logits_indices_selector
-        ]
-
-    return LogprobsLists(
-        logprob_token_ids=np.asarray(log_token_ids_list),
-        logprobs=np.asarray(logprobs_list),
-        sampled_token_ranks=np.asarray(selected_token_ranks_list),
-        cu_num_generated_tokens=cu_num_generated_tokens,
-    )
-
-
 class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
     def __init__(
         self,
         vllm_config: VllmConfig,
         devices: List[Any],
-        rank: int = 0,
-        is_first_rank: bool = True,
-        is_last_rank: bool = True,
     ):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
@@ -234,9 +214,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         self.maybe_forbid_compile = runner_utils.ForbidCompile(
         ) if envs.VLLM_XLA_CHECK_RECOMPILATION else nullcontext()
         self.dp_size = self.vllm_config.sharding_config.total_dp_size
-        self.rank = rank
-        self.is_first_rank = is_first_rank
-        self.is_last_rank = is_last_rank
 
         self._init_random()
         self._init_mesh()
@@ -247,21 +224,31 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
         # Delegate functions to specific manager classes.
         self.compilation_manager = CompilationManager(self)
-
-
-            self)
-        self.structured_decoding_manager = StructuredDecodingManager(self)
+        self.speculative_decoding_manager = SpeculativeDecodingManager(self)
+        self.structured_decoding_manager = StructuredDecodingManager(self)
         self.kv_cache_manager = KVCacheManager(self)
         self.mm_manager = MultiModalManager(self)
         self.persistent_batch_manager = PersistentBatchManager(
             self.requests, self.input_batch, self.encoder_cache,
-            self.uses_mrope, self.model_config
+            self.uses_mrope, self.model_config)
         self.lora_utils = LoraUtils(self)
 
-
-        if cache_dtype == "auto":
-
-
+        cache_config = self.cache_config
+        if cache_config.cache_dtype == "auto":
+            model_dtype = self.dtype
+            if isinstance(model_dtype, str):
+                self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
+            elif isinstance(getattr(model_dtype, 'dtype', None), jnp.dtype):
+                self.kv_cache_dtype = j2t_dtype(model_dtype.dtype)
+            elif isinstance(model_dtype, torch.dtype):
+                self.kv_cache_dtype = model_dtype
+            else:
+                raise ValueError(
+                    "KV cache is unsupported for model_dtype of %s",
+                    model_dtype)
+        else:
+            self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[
+                cache_config.cache_dtype]
 
         self._pre_async_results: AsyncPreResults | None = None
         self._substitute_placeholder_token_fn = _substitute_placeholder_token
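The hunk above adds KV-cache dtype resolution to the runner constructor: an explicit `cache_config.cache_dtype` is looked up in `TPU_STR_DTYPE_TO_TORCH_DTYPE`, while `"auto"` falls back to the model dtype, which may arrive as a string, a JAX dtype, or a torch dtype. Below is a minimal standalone sketch of those rules, assuming the hypothetical helper name `resolve_kv_cache_dtype` and a trimmed dtype map; the real code additionally converts JAX dtypes via `torchax.ops.mappings.j2t_dtype`.

```python
# Sketch only (assumed helper name, trimmed mapping); mirrors the branch order
# added in the constructor above, minus the torchax JAX-to-torch conversion.
import torch

STR_TO_TORCH = {"half": torch.half, "bfloat16": torch.bfloat16,
                "float": torch.float, "int8": torch.int8}

def resolve_kv_cache_dtype(cache_dtype: str, model_dtype) -> torch.dtype:
    if cache_dtype != "auto":
        return STR_TO_TORCH[cache_dtype]      # explicit cache dtype wins
    if isinstance(model_dtype, str):
        return STR_TO_TORCH[model_dtype]      # e.g. "bfloat16"
    if isinstance(model_dtype, torch.dtype):
        return model_dtype                    # already a torch dtype
    raise ValueError(f"KV cache is unsupported for model_dtype {model_dtype!r}")

assert resolve_kv_cache_dtype("auto", "bfloat16") is torch.bfloat16
assert resolve_kv_cache_dtype("auto", torch.float32) is torch.float32
```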
@@ -275,7 +262,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         self.rng_key = jax.random.key(self.model_config.seed)
 
     def _init_mesh(self) -> None:
-        if
+        if os.getenv("NEW_MODEL_DESIGN", False):
             self.mesh = self._create_new_model_mesh()
         else:
             # NOTE(wenxindongwork): The new MoE kernel expects a 2D mesh, so we need
@@ -286,7 +273,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         logger.info(f"Init mesh | mesh={self.mesh}")
 
     def _create_new_model_mesh(self) -> jax.sharding.Mesh:
-        num_slices =
+        num_slices = int(os.environ.get('NUM_SLICES', 1))
 
         logger.info(f"Creating new model mesh | devices={len(self.devices)}, "
                     f"num_slices={num_slices}")
@@ -355,7 +342,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             devices=self.devices)
 
     def _init_phased_profiling(self) -> None:
-        self.phased_profiling_dir =
+        self.phased_profiling_dir = os.getenv("PHASED_PROFILING_DIR", "")
         self.phase_based_profiler = None
         if self.phased_profiling_dir:
             self.phase_based_profiler = runner_utils.PhasedBasedProfiler(
@@ -397,7 +384,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             min_token_size=max(16, self.dp_size),
             max_token_size=scheduler_config.max_num_batched_tokens *
             self.dp_size,
-            padding_gap=
+            padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
         self.num_tokens_paddings_per_dp = [
             padding // self.dp_size for padding in self.num_tokens_paddings
         ]
@@ -421,14 +408,8 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
         self.input_ids_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
         self.positions_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
-
-
-        # in kv_cache_manager's maybe_reinitialize_input_batch.
-        self.block_tables_cpu = [
-            np.zeros((self.max_num_reqs, self.max_num_blocks_per_req),
-                     dtype=np.int32)
-        ]
-
+        self.block_table_cpu = np.zeros(
+            (self.max_num_reqs, self.max_num_blocks_per_req), dtype=np.int32)
         self.query_start_loc_cpu = np.zeros(self.max_num_reqs + self.dp_size,
                                             dtype=np.int32)
         self.seq_lens_cpu = np.zeros(self.max_num_reqs, dtype=np.int32)
@@ -462,6 +443,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
         # tensors for structured decoding
        self.vocab_size = self.model_config.get_vocab_size()
+        if self.lora_config is not None:
+            # lora_config.lora_extra_vocab_size is the "Maximum size of extra vocabulary that can be present in a LoRA adapter" per https://github.com/vanbasten23/vllm/blob/7f4a8b6705622fde952a2e633e86716f902d6e1b/vllm/config.py#L3040
+            self.vocab_size += self.lora_config.lora_extra_vocab_size
         self.grammar_bitmask_cpu = np.zeros(
             (self.max_num_reqs, cdiv(self.vocab_size, 32)),
             dtype=np.int32,
@@ -506,14 +490,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
         self.rng_params_for_sampling = nnx.Rngs(
             jax.random.key(self.model_config.seed)).params()
-        self.is_multimodal_model = (
-
-
-            self.model_config.hf_config, "architectures"
-        ) #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
-            and len(self.model_config.hf_config.architectures) >= 1
-            and self.model_config.hf_config.architectures[0]
-            != "Llama4ForConditionalGeneration")
+        self.is_multimodal_model = (self.model_config.is_multimodal_model
+                                    and self.get_multimodal_embeddings_fn
+                                    is not None)
 
         logger.info(f"Init model | "
                     f"hbm={common_utils.hbm_usage_gb(self.devices)}GiB")
@@ -526,7 +505,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         self.kv_cache_config = kv_cache_config
-        self.use_hybrid_kvcache = len(kv_cache_config.kv_cache_groups) > 1
         self.kv_caches = []
         self.kv_cache_manager.initialize_kv_cache(kv_cache_config)
         if has_kv_transfer_group():
@@ -539,12 +517,12 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
     def execute_model(
         self,
         scheduler_output: "VllmSchedulerOutput",
-        intermediate_tensors: Optional[
-    ) -> ModelRunnerOutput |
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+    ) -> ModelRunnerOutput | None:
         if self.execute_model_state is not None:
             raise RuntimeError("State error: sample_tokens() must be called "
                                "after execute_model() returns None.")
-        _, output = self._execute_model(scheduler_output
+        _, output = self._execute_model(scheduler_output)
         return output
 
     def sample_tokens(
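The reworked `execute_model` signature above keeps the split between `execute_model()`, which stages state and returns `None`, and `sample_tokens()`, which consumes it; calling `execute_model()` again before sampling raises the `RuntimeError` shown. A hedged, self-contained sketch of that two-phase contract, with illustrative names rather than the real runner API:

```python
# Illustrative state machine only; names and payloads are stand-ins for the
# runner's ExecuteModelState handling shown in the surrounding hunks.
class TwoPhaseRunner:
    def __init__(self) -> None:
        self._staged = None          # plays the role of execute_model_state

    def execute_model(self, scheduler_output) -> None:
        if self._staged is not None:
            raise RuntimeError("State error: sample_tokens() must be called "
                               "after execute_model() returns None.")
        self._staged = scheduler_output   # defer sampling to sample_tokens()
        return None

    def sample_tokens(self):
        staged, self._staged = self._staged, None
        return f"sampled tokens for {staged!r}"

runner = TwoPhaseRunner()
runner.execute_model("step-0")
print(runner.sample_tokens())        # consumes the staged state
```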
@@ -557,17 +535,16 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
         (scheduler_output, attn_metadata, input_ids, hidden_states, logits,
          aux_hidden_states, spec_decode_metadata, kv_connector_output,
-         logits_indices_selector
-
-
-
-
-
-
-
-
-
-         self.execute_model_state.padded_num_reqs)
+         logits_indices_selector) = (
+             self.execute_model_state.scheduler_output,
+             self.execute_model_state.attn_metadata,
+             self.execute_model_state.input_ids,
+             self.execute_model_state.hidden_states,
+             self.execute_model_state.logits,
+             self.execute_model_state.aux_hidden_states,
+             self.execute_model_state.spec_decode_metadata,
+             self.execute_model_state.kv_connector_output,
+             self.execute_model_state.logits_indices_selector)
         self.execute_model_state = None
 
         if grammar_output is not None:
@@ -581,10 +558,12 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             logits,
             arange,
         )
-        return self._sample_from_logits(
-
-
-
+        return self._sample_from_logits(scheduler_output, attn_metadata,
+                                        input_ids, hidden_states, logits,
+                                        aux_hidden_states,
+                                        spec_decode_metadata,
+                                        kv_connector_output,
+                                        logits_indices_selector)
 
     def _modify_prev_results(self):
         # If copy to host has not been done, we just wait.
@@ -670,9 +649,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
     def _execute_model(
         self,
         scheduler_output: "VllmSchedulerOutput",
-
-    ) -> tuple[AttentionMetadata, JaxIntermediateTensors | ModelRunnerOutput
-               | None]:
+    ) -> tuple[AttentionMetadata, ModelRunnerOutput | None]:
         self.persistent_batch_manager.update_states(
             scheduler_output, self.get_mrope_input_positions_fn)
         if not scheduler_output.total_num_scheduled_tokens:
@@ -695,23 +672,13 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         # TODO(pooyam): I guess we can remove returning sampling_metadata in `_prepare_inputs` after https://github.com/njhill/vllm/commit/b7433ca1a47732394b1bdea4099d98389515954b
         (
             input_ids,
-            input_positions,
             attn_metadata,
             _,
             logits_indices,
             spec_decode_metadata,
             logits_indices_selector,
-            padded_num_reqs,
         ) = self._prepare_inputs(scheduler_output)
 
-        is_llama_guard_4 = (
-            hasattr(
-                self.model_config.hf_config, "architectures"
-            ) #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
-            and len(self.model_config.hf_config.architectures) >= 1
-            and self.model_config.hf_config.architectures[0]
-            == "Llama4ForConditionalGeneration")
-
         # multi-modal support
         if self.is_multimodal_model:
             # Run the multimodal encoder if any.
@@ -719,13 +686,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 self.mm_manager.execute_mm_encoder(scheduler_output)
             mm_embeds = self.mm_manager.gather_mm_embeddings(
                 scheduler_output, input_ids.shape[0])
-        #TODO: Remove the follow elif statement once Llama Guard 4 Vision portion has been implemented
-        elif is_llama_guard_4 and any(
-                self.mm_manager.runner.requests[req_id].mm_features
-                for req_id in self.mm_manager.runner.input_batch.req_ids):
-            raise NotImplementedError(
-                "Llama Guard 4 (JAX) currently supports only text inputs. "
-                "Multimodal processing not yet implemented.")
         else:
             mm_embeds = []
 
@@ -750,6 +710,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 scheduler_output) as kv_connector_output:
             # NOTE(Wenlong): It takes both `input_ids` and `inputs_embeds`,
             # but one of them would be `None`
+
             (self.kv_caches, hidden_states,
              aux_hidden_states) = self.model_fn(
                  self.state,
@@ -757,17 +718,10 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                  input_ids,
                  attn_metadata,
                  inputs_embeds,
-                 input_positions,
                  tuple(self.layer_name_to_kvcache_index.items()),
                  lora_metadata,
-                 intermediate_tensors,
-                 self.is_first_rank,
-                 self.is_last_rank,
              )
-
-            assert isinstance(hidden_states, JaxIntermediateTensors)
-            hidden_states.kv_connector_output = kv_connector_output
-            return attn_metadata, hidden_states
+
         hidden_states = self._select_from_array_fn(hidden_states,
                                                    logits_indices)
         logits = self.compute_logits_fn(
@@ -785,8 +739,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             aux_hidden_states=aux_hidden_states,
             spec_decode_metadata=spec_decode_metadata,
             kv_connector_output=kv_connector_output,
-            logits_indices_selector=logits_indices_selector
-            padded_num_reqs=padded_num_reqs)
+            logits_indices_selector=logits_indices_selector)
         return attn_metadata, None
 
     def _sample_from_logits(
@@ -800,44 +753,23 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         spec_decode_metadata: Optional[SpecDecodeMetadata],
         kv_connector_output: Optional[KVConnectorOutput],
         logits_indices_selector: Optional[List[int]] = None,
-        padded_num_reqs: Optional[int] = None,
     ) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput:
-
-
-            self.input_batch.num_reqs, self.max_num_reqs)
-
-        sharding = None
-        if self.dp_size > 1:
-            sharding = NamedSharding(self.mesh,
-                                     PartitionSpec(ShardingAxisName.ATTN_DATA))
-
+        padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
+            self.input_batch.num_reqs, self.max_num_reqs)
         tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
-            self.mesh, self.input_batch, padded_num_reqs
-
-        # TODO(pooyam): Should we move this to `_prepare_inputs`?
-        if tpu_sampling_metadata.do_sampling:
-            self.rng_params_for_sampling, step_rng = jax.random.split(
-                self.rng_params_for_sampling)
-        else:
-            step_rng = self.rng_params_for_sampling
-
+            self.mesh, self.input_batch, padded_num_reqs)
         if spec_decode_metadata is None:
             next_tokens = sample(
-
+                self.rng_params_for_sampling,
                 self.mesh,
                 logits,
                 tpu_sampling_metadata,
             )
         else:
-            if tpu_sampling_metadata.do_sampling:
-                bonus_rng, rejection_rng = jax.random.split(step_rng)
-            else:
-                bonus_rng = step_rng
-                rejection_rng = step_rng
             bonus_logits = self._select_from_array_fn(
                 logits, spec_decode_metadata.bonus_logits_indices)
             bonus_token_ids = sample(
-
+                self.rng_params_for_sampling,
                 self.mesh,
                 bonus_logits,
                 tpu_sampling_metadata,
@@ -851,7 +783,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 target_logits=target_logits,
                 bonus_token_ids=bonus_token_ids,
                 sampling_metadata=tpu_sampling_metadata,
-                key=
+                key=self.rng_params_for_sampling,
             )
 
         if tpu_sampling_metadata.logprobs:
@@ -908,10 +840,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                     logits_indices_selector)
 
             if logprobs is not None:
-
-                logprobs_lists = _jax_logprobs_to_lists(
-                    logprobs, logits_indices_selector)
-
+                logprobs_lists = logprobs.tolists()
             else:
                 logprobs_lists = None
 
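The hunk above (and the matching one below) drops the local `_jax_logprobs_to_lists` helper, removed earlier in this diff, in favour of the logprobs tensors' own `tolists()` method. Roughly, the conversion turns device arrays into nested Python lists; a small sketch with hypothetical field names, not vLLM's actual `LogprobsTensors` implementation:

```python
# Hypothetical sketch: convert per-request logprob arrays (JAX) into plain
# Python lists, which is essentially what a tolists()-style call has to do.
import jax.numpy as jnp
import numpy as np

def logprobs_to_lists(logprob_token_ids, logprobs, selected_token_ranks):
    return (np.asarray(logprob_token_ids).tolist(),
            np.asarray(logprobs).tolist(),
            np.asarray(selected_token_ranks).tolist())

token_ids = jnp.array([[11, 42], [7, 3]])            # top-k token ids per request
lps = jnp.log(jnp.array([[0.7, 0.3], [0.9, 0.1]]))   # matching logprobs
ranks = jnp.array([0, 1])                            # rank of each sampled token
print(logprobs_to_lists(token_ids, lps, ranks))
```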
@@ -979,9 +908,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             req_state.output_token_ids.extend(sampled_ids)
 
         if logprobs is not None:
-
-            logprobs_lists = _jax_logprobs_to_lists(logprobs,
-                                                    logits_indices_selector)
+            logprobs_lists = logprobs.tolists()
         else:
             logprobs_lists = None
 
@@ -1329,6 +1256,16 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         mrope_positions = self.mrope_positions_cpu[:, :
                                                    padded_total_num_scheduled_tokens]
 
+        block_tables = self.block_table_cpu[:self.max_num_reqs]
+        for dp_rank in range(dp_size):
+            req_offset = dp_rank * max_num_reqs_per_dp_rank
+            _num_reqs = num_req_per_dp_rank[dp_rank]
+
+            block_tables[
+                req_offset:req_offset + _num_reqs, :self.
+                max_num_blocks_per_req] = self.input_batch.block_table[
+                    0].get_cpu_tensor()[req_indices_dp[dp_rank]]
+
         query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs +
                                                    dp_size]
         seq_lens = self.seq_lens_cpu[:self.max_num_reqs]
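The added block above copies each data-parallel rank's request rows out of the input batch's block table into that rank's slice of the shared CPU buffer; a later hunk then flattens the buffer to 1-D before the device transfer. A self-contained sketch of the gather with made-up sizes:

```python
# Made-up sizes and tables; shows the per-DP-rank gather into one CPU buffer
# that the hunk above performs before the table is flattened and moved to device.
import numpy as np

dp_size = 2
max_num_reqs_per_dp_rank = 4
max_num_blocks_per_req = 8

block_table_cpu = np.zeros(
    (dp_size * max_num_reqs_per_dp_rank, max_num_blocks_per_req), dtype=np.int32)

# Global table indexed by request id, and the request ids owned by each DP rank.
global_block_table = np.arange(6 * max_num_blocks_per_req,
                               dtype=np.int32).reshape(6, max_num_blocks_per_req)
req_indices_dp = [np.array([0, 2, 4]), np.array([1, 3])]
num_req_per_dp_rank = [len(r) for r in req_indices_dp]

for dp_rank in range(dp_size):
    req_offset = dp_rank * max_num_reqs_per_dp_rank
    n = num_req_per_dp_rank[dp_rank]
    block_table_cpu[req_offset:req_offset + n] = global_block_table[req_indices_dp[dp_rank]]

block_tables_1d = block_table_cpu.reshape(-1)   # kernel metadata takes a flat array
print(block_tables_1d.shape)                    # (64,)
```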
@@ -1370,59 +1307,20 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         if self.uses_mrope:
             positions = mrope_positions
 
+        # Convert block_tables to 1D on cpu.
+        block_tables = block_tables.reshape(-1)
+
         query_start_loc_cpu = query_start_loc
         logits_indices_cpu = logits_indices
         seq_lens_cpu = seq_lens
 
-        (input_ids, positions, query_start_loc, seq_lens,
-         request_distribution) = device_array(
+        (input_ids, positions, block_tables, query_start_loc, seq_lens,
+         logits_indices, request_distribution, logits_indices) = device_array(
             self.mesh,
-            (input_ids, positions, query_start_loc, seq_lens,
-             request_distribution),
+            (input_ids, positions, block_tables, query_start_loc, seq_lens,
+             logits_indices, request_distribution, logits_indices),
             sharding=data_parallel_attn_sharding,
         )
-
-        attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
-        uniform_attention_metadata: AttentionMetadata = None
-        for kv_cache_gid, kv_cache_group in enumerate(
-                self.kv_cache_config.kv_cache_groups):
-            block_tables = self.block_tables_cpu[kv_cache_gid][:self.
-                                                               max_num_reqs]
-            for dp_rank in range(dp_size):
-                req_offset = dp_rank * max_num_reqs_per_dp_rank
-                _num_reqs = num_req_per_dp_rank[dp_rank]
-
-                block_tables[
-                    req_offset:req_offset + _num_reqs, :self.
-                    max_num_blocks_per_req] = self.input_batch.block_table[
-                        0].get_cpu_tensor()[req_indices_dp[dp_rank]]
-            # Convert block_tables to 1D on cpu.
-            block_tables = block_tables.reshape(-1)
-            block_tables = device_array(
-                self.mesh,
-                (block_tables),
-                sharding=data_parallel_attn_sharding,
-            )
-
-            attention_metadata_gid = AttentionMetadata(
-                input_positions=positions,
-                block_tables=block_tables,
-                seq_lens=seq_lens,
-                query_start_loc=query_start_loc,
-                request_distribution=request_distribution,
-            )
-
-            # This is for making these cpu buffers hidden during tracing
-            attention_metadata_gid.query_start_loc_cpu = query_start_loc_cpu
-            attention_metadata_gid.seq_lens_cpu = seq_lens_cpu
-
-            if not self.use_hybrid_kvcache:
-                uniform_attention_metadata = attention_metadata_gid
-            else:
-                for layer_name in kv_cache_group.layer_names:
-                    attention_metadata_per_layer[
-                        layer_name] = attention_metadata_gid
-
         # Async scheduling: substitute placeholder tokens for DP
         if self.scheduler_config.async_scheduling and self._pre_async_results is not None:
             # Collect all token indices that need substitution across all DP ranks
@@ -1451,19 +1349,25 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             padded_total_num_scheduled_tokens,
         )
 
-
-
-
-
+        attention_metadata = AttentionMetadata(
+            input_positions=positions,
+            block_tables=block_tables,
+            seq_lens=seq_lens,
+            query_start_loc=query_start_loc,
+            request_distribution=request_distribution,
+        )
+
+        # This is for making these cpu buffers hidden during tracing
+        attention_metadata.query_start_loc_cpu = query_start_loc_cpu
+        attention_metadata.seq_lens_cpu = seq_lens_cpu
+
         return (
             input_ids,
-            positions,
             attention_metadata,
             sampling_metadata,
             logits_indices,
             spec_decode_metadata,
             logits_indices_selector,
-            padded_num_reqs,
         )
 
     def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"):
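With the hybrid-KV-cache branch removed, `_prepare_inputs` now builds one `AttentionMetadata` shared by all layers and attaches the CPU copies of `query_start_loc` and `seq_lens` as plain attributes so they stay out of the jitted call signature. A stand-in sketch of that pattern; the dataclass below is illustrative, not the real `AttentionMetadata`:

```python
# Stand-in dataclass; demonstrates attaching host-side buffers as extra
# attributes so only the device arrays appear in traced arguments.
from dataclasses import dataclass
from typing import Any

import numpy as np

@dataclass
class AttnMetaSketch:
    input_positions: Any
    block_tables: Any
    seq_lens: Any
    query_start_loc: Any
    request_distribution: Any

meta = AttnMetaSketch(input_positions=None, block_tables=None, seq_lens=None,
                      query_start_loc=None, request_distribution=[2, 0, 0])
# Kept as ordinary attributes for host-side bookkeeping during input prep.
meta.query_start_loc_cpu = np.array([0, 3, 8], dtype=np.int32)
meta.seq_lens_cpu = np.array([3, 5], dtype=np.int32)
print(meta.seq_lens_cpu)
```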
@@ -1564,6 +1468,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         positions = self.positions_cpu[:padded_total_num_scheduled_tokens]
         mrope_positions = self.mrope_positions_cpu[:, :
                                                    padded_total_num_scheduled_tokens]
+        block_tables = self.block_table_cpu[:self.max_num_reqs]
+        block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
+            self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs])
 
         # TODO(pooyam): Some paddings are up to `num_reqs_paddings` (spec decoding, select hidden states, etc) and some other are to `max_num_reqs` (block table, seq_lens). We should stick to one of them maybe?
         query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1]
@@ -1592,44 +1499,16 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             self.mesh, self.input_batch, padded_num_reqs)
         if self.uses_mrope:
             positions = mrope_positions
+
+        # Convert block_tables to 1D on cpu.
+        block_tables = block_tables.reshape(-1)
+
         query_start_loc_cpu = query_start_loc
         seq_lens_cpu = seq_lens
-
-        (input_ids, positions, query_start_loc, seq_lens,
+        (input_ids, positions, block_tables, query_start_loc, seq_lens,
         logits_indices, request_distribution) = device_array(
-            self.mesh, (input_ids, positions,
-                        logits_indices, request_distribution))
-
-        attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
-        uniform_attention_metadata: AttentionMetadata = None
-        for kv_cache_gid, kv_cache_group in enumerate(
-                self.kv_cache_config.kv_cache_groups):
-            block_tables = self.block_tables_cpu[kv_cache_gid][:self.
-                                                               max_num_reqs]
-            block_tables[:num_reqs] = (
-                self.input_batch.block_table[kv_cache_gid].get_cpu_tensor()
-                [:num_reqs])
-            # Convert block_tables to 1D on cpu.
-            block_tables = block_tables.reshape(-1)
-            block_tables = device_array(self.mesh, (block_tables))
-
-            attention_metadata_gid = AttentionMetadata(
-                input_positions=positions,
-                block_tables=block_tables,
-                seq_lens=seq_lens,
-                query_start_loc=query_start_loc,
-                request_distribution=request_distribution)
-            # This is for making these cpu buffers hidden during tracing
-            attention_metadata_gid.query_start_loc_cpu = query_start_loc_cpu
-            attention_metadata_gid.seq_lens_cpu = seq_lens_cpu
-
-            if not self.use_hybrid_kvcache:
-                # all layers share the same attention metadata
-                uniform_attention_metadata = attention_metadata_gid
-            else:
-                for layer_name in kv_cache_group.layer_names:
-                    attention_metadata_per_layer[
-                        layer_name] = attention_metadata_gid
+            self.mesh, (input_ids, positions, block_tables, query_start_loc,
+                        seq_lens, logits_indices, request_distribution))
 
         if self.scheduler_config.async_scheduling and len(
                 token_in_tpu_cur_input_indices) > 0:
@@ -1642,15 +1521,20 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             self.lora_utils.set_active_loras(
                 num_scheduled_tokens_per_req, total_num_scheduled_tokens,
                 padded_total_num_scheduled_tokens)
-        logits_indices_selector = None
 
-
-
-
-
-
-
-
+        attention_metadata = AttentionMetadata(
+            input_positions=positions,
+            block_tables=block_tables,
+            seq_lens=seq_lens,
+            query_start_loc=query_start_loc,
+            request_distribution=request_distribution)
+
+        # This is for making these cpu buffers hidden during tracing
+        attention_metadata.query_start_loc_cpu = query_start_loc_cpu
+        attention_metadata.seq_lens_cpu = seq_lens_cpu
+        logits_indices_selector = None
+        return (input_ids, attention_metadata, sampling_metadata,
+                logits_indices, spec_decode_metadata, logits_indices_selector)
 
     def _get_input_ids_embeds(self, input_ids: jax.Array,
                               mm_embeds: list[jax.Array]):
@@ -1710,35 +1594,3 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             mappings=mappings,
             transpose_keys=transpose_keys,
             shard=shard)
-
-    def get_intermediate_tensor_spec(self, num_tokens: int):
-        impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
-        jax_dtype = t2j_dtype(self.dtype) if impl == "vllm" else self.dtype
-        num_padded_tokens = runner_utils.get_padded_token_len(
-            self.num_tokens_paddings, num_tokens)
-        sharding = NamedSharding(self.mesh, PartitionSpec())
-        hidden_size = self.model_config.get_hidden_size()
-        spec = jax.ShapeDtypeStruct(shape=(num_padded_tokens, hidden_size),
-                                    dtype=jax_dtype,
-                                    sharding=sharding)
-        tensor_spec = {"hidden_states": spec, "residual": spec}
-        return tensor_spec
-
-    def get_uuid_for_jax_transfer(self,
-                                  scheduler_output: "VllmSchedulerOutput",
-                                  rank: int, step: int) -> int:
-        '''
-        Get a uuid for jax.transfer, here we use the hash of
-        scheduler_output + counter_step + sender's rank
-        '''
-        scheduler_output_str = ""
-        if not scheduler_output.num_scheduled_tokens:
-            scheduler_output_str = "empty_batch"
-        else:
-            scheduler_output_str = str(
-                sorted(scheduler_output.num_scheduled_tokens.items()))
-        unique_str = f'{scheduler_output_str} {step} {rank}'
-        import hashlib
-        hasher = hashlib.sha1()
-        hasher.update(unique_str.encode('utf-8'))
-        return int.from_bytes(hasher.digest()[:8], 'big')