tpu-inference 0.11.1.dev202511150811__py3-none-any.whl → 0.11.1.dev202512030818__py3-none-any.whl
This diff shows the changes between these two publicly released package versions as they appear in their respective registries. It is provided for informational purposes only.
Potentially problematic release: this version of tpu-inference might be problematic; see the package's registry page for details.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/lora/test_layers.py +0 -6
- tests/lora/utils.py +0 -8
- tests/test_envs.py +32 -11
- tests/test_utils.py +1 -2
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +3 -4
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +61 -8
- tpu_inference/executors/ray_distributed_executor.py +31 -11
- tpu_inference/kernels/fused_moe/v1/kernel.py +641 -110
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +213 -126
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +74 -25
- tpu_inference/layers/vllm/quantization/common.py +6 -1
- tpu_inference/layers/vllm/quantization/mxfp4.py +137 -62
- tpu_inference/layers/vllm/quantization/unquantized.py +107 -113
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +45 -11
- tpu_inference/models/jax/llama3.py +2 -1
- tpu_inference/models/jax/llama_eagle3.py +8 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +2 -1
- tpu_inference/models/jax/qwen2_5_vl.py +163 -48
- tpu_inference/models/jax/qwen3.py +2 -1
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +3 -6
- tpu_inference/models/jax/utils/weight_utils.py +198 -143
- tpu_inference/models/vllm/vllm_model_wrapper.py +14 -7
- tpu_inference/platforms/tpu_platform.py +28 -22
- tpu_inference/runner/compilation_manager.py +144 -59
- tpu_inference/runner/kv_cache_manager.py +17 -18
- tpu_inference/runner/persistent_batch_manager.py +40 -2
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +271 -147
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +71 -21
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +36 -13
- tpu_inference/worker/tpu_worker.py +162 -25
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/METADATA +3 -2
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/RECORD +48 -53
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/top_level.txt +0 -0
Diff of tpu_inference/runner/tpu_runner.py (+271 -147):

@@ -10,17 +10,16 @@ import jax
 import jax.numpy as jnp
 import jaxtyping
 import numpy as np
-import torch
-import vllm.envs as envs
+import vllm.envs as vllm_envs
 from flax import nnx
 from jax.experimental import mesh_utils
 from jax.sharding import NamedSharding, PartitionSpec
-from torchax.ops.mappings import j2t_dtype
+from torchax.ops.mappings import t2j_dtype
 from vllm.config import VllmConfig
+from vllm.distributed import get_pp_group
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.forward_context import set_forward_context
-from vllm.sequence import IntermediateTensors
 from vllm.tasks import SupportedTask
 from vllm.utils.math_utils import cdiv
 from vllm.v1.core.sched.output import GrammarOutput

@@ -35,6 +34,7 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import \
     KVConnectorModelRunnerMixin
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

+import tpu_inference.envs as envs
 from tpu_inference import utils as common_utils
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.common.sharding import (MESH_AXIS_NAMES,
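Note: after this change the bare `envs` alias resolves to `tpu_inference.envs` (which grows by +61 -8 in this release), while upstream vLLM settings are reached through `vllm_envs`; call sites below such as `envs.NEW_MODEL_DESIGN`, `envs.PHASED_PROFILING_DIR`, and `vllm_envs.VLLM_TPU_BUCKET_PADDING_GAP` reflect the split. A minimal sketch of the resulting usage; the attribute names are taken from call sites in this diff, their definitions are assumed:

    # Sketch: the two settings namespaces after the import rename.
    import vllm.envs as vllm_envs      # upstream vLLM settings
    import tpu_inference.envs as envs  # TPU-plugin settings

    def pick_mesh_strategy() -> str:
        # `envs` no longer means vllm.envs at these call sites.
        return "new-model mesh" if envs.NEW_MODEL_DESIGN else "legacy 2D mesh"

    padding_gap = vllm_envs.VLLM_TPU_BUCKET_PADDING_GAP  # still vLLM's knob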
@@ -48,6 +48,8 @@ from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata
 from tpu_inference.logger import init_logger
 from tpu_inference.models.common.model_loader import get_model
+from tpu_inference.models.jax.jax_intermediate_tensor import \
+    JaxIntermediateTensors
 from tpu_inference.models.jax.utils.weight_utils import (
     shard_put, transfer_state_with_mappings)
 from tpu_inference.runner import utils as runner_utils

@@ -64,7 +66,7 @@ from tpu_inference.runner.structured_decoding_manager import \
     StructuredDecodingManager
 from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
 from tpu_inference.utils import (device_array, make_optimized_mesh,
-                                 time_function)
+                                 time_function, to_torch_dtype)

 logger = init_logger(__name__)

@@ -78,17 +80,6 @@ DUMMY_METADATA = AttentionMetadata(
     request_distribution=[0, 0, 0],
 )

-TPU_STR_DTYPE_TO_TORCH_DTYPE = {
-    "half": torch.half,
-    "bfloat16": torch.bfloat16,
-    "float": torch.float,
-    "fp8": torch.float8_e4m3fn,
-    "fp8_e4m3": torch.float8_e4m3fn,
-    "fp8_e5m2": torch.float8_e5m2,
-    "int8": torch.int8,
-    "uint8": torch.uint8,
-}
-

 class AsyncTPUModelRunnerOutput(AsyncModelRunnerOutput):
     """Holds asynchronous model output specifically from a TPU runner.
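Note: the module-level dtype table is not simply deleted; the runner now delegates to `to_torch_dtype`, imported from `tpu_inference.utils` above (utils.py changes by +36 -13 in this release). A hedged sketch of what such a helper plausibly does; only the name and call site are in the diff, and the body below is modeled on the removed table:

    # Hedged sketch of a to_torch_dtype helper; the real implementation in
    # tpu_inference/utils.py may also handle jnp dtypes and other spellings.
    import torch

    _STR_TO_TORCH = {
        "half": torch.half,
        "bfloat16": torch.bfloat16,
        "float": torch.float,
        "fp8": torch.float8_e4m3fn,
        "fp8_e4m3": torch.float8_e4m3fn,
        "fp8_e5m2": torch.float8_e5m2,
        "int8": torch.int8,
        "uint8": torch.uint8,
    }

    def to_torch_dtype(dtype):
        if isinstance(dtype, torch.dtype):
            return dtype
        if isinstance(dtype, str):
            return _STR_TO_TORCH[dtype]
        raise ValueError(f"unsupported KV cache dtype: {dtype!r}")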
@@ -153,6 +144,7 @@ class ExecuteModelState:
     spec_decode_metadata: Optional[SpecDecodeMetadata]
     kv_connector_output: Optional[KVConnectorOutput]
     logits_indices_selector: Optional[List[int]] = None
+    padded_num_reqs: Optional[int] = None


 @functools.partial(jax.jit, donate_argnums=(0, 1, 2))

@@ -190,18 +182,28 @@ def _substitute_placeholder_token(
     return input_ids.at[token_in_tpu_cur_input_indices].set(update_values)


-def _reorder_logits_indices(logprobs_lists, logits_indices_selector):
+def _jax_logprobs_to_lists(logprobs_tensors,
+                           logits_indices_selector=None,
+                           cu_num_generated_tokens=None):
+    """Convert JAX LogprobsTensors to LogprobsLists by converting JAX arrays to numpy."""
+    log_token_ids_list = logprobs_tensors.logprob_token_ids.tolist()
+    logprobs_list = logprobs_tensors.logprobs.tolist()
+    selected_token_ranks_list = logprobs_tensors.selected_token_ranks.tolist()
+
+    if logits_indices_selector is not None:
+        log_token_ids_list = [
+            log_token_ids_list[i] for i in logits_indices_selector
+        ]
+        logprobs_list = [logprobs_list[i] for i in logits_indices_selector]
+        selected_token_ranks_list = [
+            selected_token_ranks_list[i] for i in logits_indices_selector
+        ]
+
     return LogprobsLists(
-        logprob_token_ids=[
-            logprobs_lists.logprob_token_ids[i]
-            for i in logits_indices_selector
-        ],
-        logprobs=[logprobs_lists.logprobs[i] for i in logits_indices_selector],
-        sampled_token_ranks=[
-            logprobs_lists.sampled_token_ranks[i]
-            for i in logits_indices_selector
-        ],
-        cu_num_generated_tokens=logprobs_lists.cu_num_generated_tokens,
+        logprob_token_ids=np.asarray(log_token_ids_list),
+        logprobs=np.asarray(logprobs_list),
+        sampled_token_ranks=np.asarray(selected_token_ranks_list),
+        cu_num_generated_tokens=cu_num_generated_tokens,
     )
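Note: the new helper pulls device-resident JAX arrays to the host once (`.tolist()`), applies the optional reorder, and packs numpy arrays into vLLM's `LogprobsLists`; the old `_reorder_logits_indices` reindexed already-materialized lists. An illustrative call, with `LogprobsTensors` stood in by a namedtuple carrying the three fields the helper reads (the construction here is hypothetical):

    # Illustrative use of _jax_logprobs_to_lists with toy data.
    from collections import namedtuple
    import jax.numpy as jnp

    FakeLogprobsTensors = namedtuple(
        "FakeLogprobsTensors",
        ["logprob_token_ids", "logprobs", "selected_token_ranks"])

    tensors = FakeLogprobsTensors(
        logprob_token_ids=jnp.array([[3, 7], [1, 4]]),
        logprobs=jnp.array([[-0.1, -2.3], [-0.5, -1.2]]),
        selected_token_ranks=jnp.array([0, 1]),
    )
    # Undo a dp shuffle by emitting row 1 before row 0.
    lists = _jax_logprobs_to_lists(tensors, logits_indices_selector=[1, 0])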
@@ -211,6 +213,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         self,
         vllm_config: VllmConfig,
         devices: List[Any],
+        rank: int = 0,
+        is_first_rank: bool = True,
+        is_last_rank: bool = True,
     ):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config

@@ -229,6 +234,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         self.maybe_forbid_compile = runner_utils.ForbidCompile(
         ) if envs.VLLM_XLA_CHECK_RECOMPILATION else nullcontext()
         self.dp_size = self.vllm_config.sharding_config.total_dp_size
+        self.rank = rank
+        self.is_first_rank = is_first_rank
+        self.is_last_rank = is_last_rank

         self._init_random()
         self._init_mesh()
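Note: the three new constructor fields thread pipeline-parallel placement into the runner; below, only the last rank instantiates the sampling-side managers, and `get_pp_group()` gates an early return of intermediate tensors. A hedged sketch of worker-side construction (the call site is an assumption; `tpu_worker.py` changes by +162 -25 in this release):

    # Hedged sketch: how a PP-aware worker might construct the runner.
    # The keyword arguments are real (added in this diff); sourcing them
    # from vllm.distributed.get_pp_group() is an assumption.
    from vllm.distributed import get_pp_group

    def make_runner(vllm_config, devices):
        pp = get_pp_group()
        return TPUModelRunner(
            vllm_config,
            devices,
            rank=pp.rank_in_group,
            is_first_rank=pp.is_first_rank,
            is_last_rank=pp.is_last_rank,
        )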
@@ -239,31 +247,21 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

         # Delegate functions to specific manager classes.
         self.compilation_manager = CompilationManager(self)
-        self.speculative_decoding_manager = SpeculativeDecodingManager(self)
-        self.structured_decoding_manager = StructuredDecodingManager(self)
+        if self.is_last_rank:
+            self.speculative_decoding_manager = SpeculativeDecodingManager(
+                self)
+            self.structured_decoding_manager = StructuredDecodingManager(self)
         self.kv_cache_manager = KVCacheManager(self)
         self.mm_manager = MultiModalManager(self)
         self.persistent_batch_manager = PersistentBatchManager(
             self.requests, self.input_batch, self.encoder_cache,
-            self.uses_mrope, self.model_config)
+            self.uses_mrope, self.model_config, self.is_last_rank)
         self.lora_utils = LoraUtils(self)

-
-        if cache_config.cache_dtype == "auto":
-            model_dtype = self.dtype
-            if isinstance(model_dtype, str):
-                self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
-            elif isinstance(getattr(model_dtype, 'dtype', None), jnp.dtype):
-                self.kv_cache_dtype = j2t_dtype(model_dtype.dtype)
-            elif isinstance(model_dtype, torch.dtype):
-                self.kv_cache_dtype = model_dtype
-            else:
-                raise ValueError(
-                    "KV cache is unsupported for model_dtype of %s",
-                    model_dtype)
-        else:
-            self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[
-                cache_config.cache_dtype]
+        cache_dtype = self.cache_config.cache_dtype
+        if cache_dtype == "auto":
+            cache_dtype = self.dtype
+        self.kv_cache_dtype = to_torch_dtype(cache_dtype)

         self._pre_async_results: AsyncPreResults | None = None
         self._substitute_placeholder_token_fn = _substitute_placeholder_token
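Note: the dtype plumbing collapses to three lines: "auto" defers to the model dtype, and everything funnels through `to_torch_dtype`. Worked example (values illustrative):

    # Illustrative resolution path for the KV cache dtype.
    cache_dtype = "auto"          # from cache_config.cache_dtype
    model_dtype = "bfloat16"      # self.dtype

    if cache_dtype == "auto":
        cache_dtype = model_dtype
    kv_cache_dtype = to_torch_dtype(cache_dtype)  # -> torch.bfloat16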
@@ -277,7 +275,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         self.rng_key = jax.random.key(self.model_config.seed)

     def _init_mesh(self) -> None:
-        if
+        if envs.NEW_MODEL_DESIGN:
             self.mesh = self._create_new_model_mesh()
         else:
             # NOTE(wenxindongwork): The new MoE kernel expects a 2D mesh, so we need

@@ -288,7 +286,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         logger.info(f"Init mesh | mesh={self.mesh}")

     def _create_new_model_mesh(self) -> jax.sharding.Mesh:
-        num_slices =
+        num_slices = envs.NUM_SLICES

         logger.info(f"Creating new model mesh | devices={len(self.devices)}, "
                     f"num_slices={num_slices}")

@@ -357,7 +355,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 devices=self.devices)

     def _init_phased_profiling(self) -> None:
-        self.phased_profiling_dir =
+        self.phased_profiling_dir = envs.PHASED_PROFILING_DIR
         self.phase_based_profiler = None
         if self.phased_profiling_dir:
             self.phase_based_profiler = runner_utils.PhasedBasedProfiler(

@@ -399,7 +397,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             min_token_size=max(16, self.dp_size),
             max_token_size=scheduler_config.max_num_batched_tokens *
             self.dp_size,
-            padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
+            padding_gap=vllm_envs.VLLM_TPU_BUCKET_PADDING_GAP)
         self.num_tokens_paddings_per_dp = [
             padding // self.dp_size for padding in self.num_tokens_paddings
         ]
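Note: `num_tokens_paddings` is a ladder of batch-size buckets; every batch is padded up to the nearest bucket so XLA compiles a bounded set of shapes, and the per-DP-rank ladder is derived by dividing each bucket by `dp_size`. A hedged sketch of the bucket generator, modeled on vLLM's TPU scheme of exponential growth up to `padding_gap` followed by linear steps (the real helper may differ):

    # Hedged sketch of exponential-then-linear token-padding buckets.
    def token_paddings(min_size: int, max_size: int, gap: int) -> list[int]:
        paddings, num = [], min_size
        if gap == 0:
            while num < max_size:       # pure powers-of-two bucketing
                paddings.append(num)
                num *= 2
            paddings.append(num)        # first bucket covering max_size
            return paddings
        while num <= gap:               # exponential region
            paddings.append(num)
            num *= 2
        num = paddings[-1] if paddings else min_size
        while num < max_size:           # linear region, step = gap
            num += gap
            paddings.append(num)
        return paddings

    # token_paddings(16, 512, 128) -> [16, 32, 64, 128, 256, 384, 512]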
@@ -423,8 +421,14 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

         self.input_ids_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
         self.positions_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
-        self.block_table_cpu = np.zeros(
-            (self.max_num_reqs, self.max_num_blocks_per_req), dtype=np.int32)
+        # Note: self.input_batch and self.block_tables_cpu are both initialized
+        # with only 1 block_size. For hybrid kv cache, it will be re-init
+        # in kv_cache_manager's maybe_reinitialize_input_batch.
+        self.block_tables_cpu = [
+            np.zeros((self.max_num_reqs, self.max_num_blocks_per_req),
+                     dtype=np.int32)
+        ]
+
         self.query_start_loc_cpu = np.zeros(self.max_num_reqs + self.dp_size,
                                             dtype=np.int32)
         self.seq_lens_cpu = np.zeros(self.max_num_reqs, dtype=np.int32)

@@ -458,9 +462,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

         # tensors for structured decoding
         self.vocab_size = self.model_config.get_vocab_size()
-        if self.lora_config is not None:
-            # lora_config.lora_extra_vocab_size is the "Maximum size of extra vocabulary that can be present in a LoRA adapter" per https://github.com/vanbasten23/vllm/blob/7f4a8b6705622fde952a2e633e86716f902d6e1b/vllm/config.py#L3040
-            self.vocab_size += self.lora_config.lora_extra_vocab_size
         self.grammar_bitmask_cpu = np.zeros(
             (self.max_num_reqs, cdiv(self.vocab_size, 32)),
             dtype=np.int32,

@@ -505,9 +506,14 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

         self.rng_params_for_sampling = nnx.Rngs(
             jax.random.key(self.model_config.seed)).params()
-        self.is_multimodal_model = (
-            self.model_config.is_multimodal_model
-            and self.get_multimodal_embeddings_fn is not None)
+        self.is_multimodal_model = (
+            self.model_config.is_multimodal_model
+            and self.get_multimodal_embeddings_fn is not None and hasattr(
+                self.model_config.hf_config, "architectures"
+            ) #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
+            and len(self.model_config.hf_config.architectures) >= 1
+            and self.model_config.hf_config.architectures[0]
+            != "Llama4ForConditionalGeneration")

         logger.info(f"Init model | "
                     f"hbm={common_utils.hbm_usage_gb(self.devices)}GiB")

@@ -520,6 +526,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         self.kv_cache_config = kv_cache_config
+        self.use_hybrid_kvcache = len(kv_cache_config.kv_cache_groups) > 1
         self.kv_caches = []
         self.kv_cache_manager.initialize_kv_cache(kv_cache_config)
         if has_kv_transfer_group():

@@ -532,12 +539,12 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
     def execute_model(
         self,
         scheduler_output: "VllmSchedulerOutput",
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> ModelRunnerOutput | None:
+        intermediate_tensors: Optional[JaxIntermediateTensors] = None,
+    ) -> ModelRunnerOutput | JaxIntermediateTensors | None:
         if self.execute_model_state is not None:
             raise RuntimeError("State error: sample_tokens() must be called "
                                "after execute_model() returns None.")
-        _, output = self._execute_model(scheduler_output)
+        _, output = self._execute_model(scheduler_output, intermediate_tensors)
         return output

     def sample_tokens(
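Note: with pipeline parallelism, `execute_model` on a non-last rank now returns `JaxIntermediateTensors` for the next stage instead of a `ModelRunnerOutput`. A hedged sketch of the driver loop this signature implies (the transport functions are hypothetical stand-ins; the executor wiring lives in `ray_distributed_executor.py`, +31 -11 in this release):

    # Hedged sketch of the PP hand-off enabled by this signature.
    def step(runner, scheduler_output, recv_from_prev, send_to_next):
        intermediate = None
        if not runner.is_first_rank:
            intermediate = recv_from_prev()    # JaxIntermediateTensors
        out = runner.execute_model(scheduler_output, intermediate)
        if not runner.is_last_rank:
            send_to_next(out)                  # forward activations downstream
            return None
        return out                             # ModelRunnerOutput on last rank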
@@ -550,16 +557,17 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

         (scheduler_output, attn_metadata, input_ids, hidden_states, logits,
          aux_hidden_states, spec_decode_metadata, kv_connector_output,
-         logits_indices_selector) = (self.execute_model_state.scheduler_output,
-                                     self.execute_model_state.attn_metadata,
-                                     self.execute_model_state.input_ids,
-                                     self.execute_model_state.hidden_states,
-                                     self.execute_model_state.logits,
-                                     self.execute_model_state.aux_hidden_states,
-                                     self.execute_model_state.spec_decode_metadata,
-                                     self.execute_model_state.kv_connector_output,
-                                     self.execute_model_state.logits_indices_selector)
+         logits_indices_selector,
+         padded_num_reqs) = (self.execute_model_state.scheduler_output,
+                             self.execute_model_state.attn_metadata,
+                             self.execute_model_state.input_ids,
+                             self.execute_model_state.hidden_states,
+                             self.execute_model_state.logits,
+                             self.execute_model_state.aux_hidden_states,
+                             self.execute_model_state.spec_decode_metadata,
+                             self.execute_model_state.kv_connector_output,
+                             self.execute_model_state.logits_indices_selector,
+                             self.execute_model_state.padded_num_reqs)
         self.execute_model_state = None

         if grammar_output is not None:

@@ -573,12 +581,10 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             logits,
             arange,
         )
-        return self._sample_from_logits(
-            scheduler_output, attn_metadata, input_ids,
-            hidden_states, logits, aux_hidden_states,
-            spec_decode_metadata,
-            kv_connector_output,
-            logits_indices_selector)
+        return self._sample_from_logits(
+            scheduler_output, attn_metadata, input_ids, hidden_states, logits,
+            aux_hidden_states, spec_decode_metadata, kv_connector_output,
+            logits_indices_selector, padded_num_reqs)

     def _modify_prev_results(self):
         # If copy to host has not been done, we just wait.

@@ -664,7 +670,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
     def _execute_model(
         self,
         scheduler_output: "VllmSchedulerOutput",
-    ) -> tuple[AttentionMetadata, ModelRunnerOutput | None]:
+        intermediate_tensors: Optional[JaxIntermediateTensors] = None,
+    ) -> tuple[AttentionMetadata, JaxIntermediateTensors | ModelRunnerOutput
+               | None]:
         self.persistent_batch_manager.update_states(
             scheduler_output, self.get_mrope_input_positions_fn)
         if not scheduler_output.total_num_scheduled_tokens:

@@ -687,13 +695,23 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         # TODO(pooyam): I guess we can remove returning sampling_metadata in `_prepare_inputs` after https://github.com/njhill/vllm/commit/b7433ca1a47732394b1bdea4099d98389515954b
         (
             input_ids,
+            input_positions,
             attn_metadata,
             _,
             logits_indices,
             spec_decode_metadata,
             logits_indices_selector,
+            padded_num_reqs,
         ) = self._prepare_inputs(scheduler_output)

+        is_llama_guard_4 = (
+            hasattr(
+                self.model_config.hf_config, "architectures"
+            ) #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
+            and len(self.model_config.hf_config.architectures) >= 1
+            and self.model_config.hf_config.architectures[0]
+            == "Llama4ForConditionalGeneration")
+
         # multi-modal support
         if self.is_multimodal_model:
             # Run the multimodal encoder if any.

@@ -701,6 +719,13 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             self.mm_manager.execute_mm_encoder(scheduler_output)
             mm_embeds = self.mm_manager.gather_mm_embeddings(
                 scheduler_output, input_ids.shape[0])
+        #TODO: Remove the follow elif statement once Llama Guard 4 Vision portion has been implemented
+        elif is_llama_guard_4 and any(
+                self.mm_manager.runner.requests[req_id].mm_features
+                for req_id in self.mm_manager.runner.input_batch.req_ids):
+            raise NotImplementedError(
+                "Llama Guard 4 (JAX) currently supports only text inputs. "
+                "Multimodal processing not yet implemented.")
         else:
             mm_embeds = []

@@ -725,7 +750,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 scheduler_output) as kv_connector_output:
             # NOTE(Wenlong): It takes both `input_ids` and `inputs_embeds`,
             # but one of them would be `None`
-
             (self.kv_caches, hidden_states,
              aux_hidden_states) = self.model_fn(
                  self.state,

@@ -733,10 +757,17 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                  input_ids,
                  attn_metadata,
                  inputs_embeds,
+                 input_positions,
                  tuple(self.layer_name_to_kvcache_index.items()),
                  lora_metadata,
+                 intermediate_tensors,
+                 self.is_first_rank,
+                 self.is_last_rank,
              )
-
+            if not get_pp_group().is_last_rank:
+                assert isinstance(hidden_states, JaxIntermediateTensors)
+                hidden_states.kv_connector_output = kv_connector_output
+                return attn_metadata, hidden_states
         hidden_states = self._select_from_array_fn(hidden_states,
                                                    logits_indices)
         logits = self.compute_logits_fn(

@@ -754,7 +785,8 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             aux_hidden_states=aux_hidden_states,
             spec_decode_metadata=spec_decode_metadata,
             kv_connector_output=kv_connector_output,
-            logits_indices_selector=logits_indices_selector)
+            logits_indices_selector=logits_indices_selector,
+            padded_num_reqs=padded_num_reqs)
         return attn_metadata, None

     def _sample_from_logits(

@@ -768,23 +800,44 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         spec_decode_metadata: Optional[SpecDecodeMetadata],
         kv_connector_output: Optional[KVConnectorOutput],
         logits_indices_selector: Optional[List[int]] = None,
+        padded_num_reqs: Optional[int] = None,
     ) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput:
-        padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
-            self.input_batch.num_reqs, self.max_num_reqs)
+        if padded_num_reqs is None:
+            padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
+                self.input_batch.num_reqs, self.max_num_reqs)
+
+        sharding = None
+        if self.dp_size > 1:
+            sharding = NamedSharding(self.mesh,
+                                     PartitionSpec(ShardingAxisName.ATTN_DATA))
+
         tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
-            self.mesh, self.input_batch, padded_num_reqs)
+            self.mesh, self.input_batch, padded_num_reqs, sharding=sharding)
+
+        # TODO(pooyam): Should we move this to `_prepare_inputs`?
+        if tpu_sampling_metadata.do_sampling:
+            self.rng_params_for_sampling, step_rng = jax.random.split(
+                self.rng_params_for_sampling)
+        else:
+            step_rng = self.rng_params_for_sampling
+
         if spec_decode_metadata is None:
             next_tokens = sample(
-                self.rng_params_for_sampling,
+                step_rng,
                 self.mesh,
                 logits,
                 tpu_sampling_metadata,
             )
         else:
+            if tpu_sampling_metadata.do_sampling:
+                bonus_rng, rejection_rng = jax.random.split(step_rng)
+            else:
+                bonus_rng = step_rng
+                rejection_rng = step_rng
             bonus_logits = self._select_from_array_fn(
                 logits, spec_decode_metadata.bonus_logits_indices)
             bonus_token_ids = sample(
-                self.rng_params_for_sampling,
+                bonus_rng,
                 self.mesh,
                 bonus_logits,
                 tpu_sampling_metadata,
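Note: sampling now follows the standard JAX key-splitting discipline: the persistent key advances once per step, and per-use subkeys (bonus vs. rejection sampling) are split from the step key, so no PRNG state is ever consumed twice. Minimal pattern:

    # Minimal sketch of the key discipline adopted here.
    import jax

    key = jax.random.key(0)                    # persistent sampler state
    for _ in range(3):                         # three decode steps
        key, step_rng = jax.random.split(key)  # advance once per step
        bonus_rng, rejection_rng = jax.random.split(step_rng)
        # ...use bonus_rng and rejection_rng exactly once each...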
@@ -798,7 +851,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 target_logits=target_logits,
                 bonus_token_ids=bonus_token_ids,
                 sampling_metadata=tpu_sampling_metadata,
-                key=self.rng_params_for_sampling,
+                key=rejection_rng,
             )

         if tpu_sampling_metadata.logprobs:

@@ -856,10 +909,8 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

             if logprobs is not None:
                 # Map logprobs back to the pre-dp shuffling order
-                logprobs_lists = logprobs.tolists()
-                if logits_indices_selector is not None:
-                    logprobs_lists = _reorder_logits_indices(
-                        logprobs_lists, logits_indices_selector)
+                logprobs_lists = _jax_logprobs_to_lists(
+                    logprobs, logits_indices_selector)

             else:
                 logprobs_lists = None

@@ -929,10 +980,8 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

         if logprobs is not None:
             # Map logprobs back to the pre-dp shuffling order
-            logprobs_lists = logprobs.tolists()
-            if logits_indices_selector is not None:
-                logprobs_lists = _reorder_logits_indices(
-                    logprobs_lists, logits_indices_selector)
+            logprobs_lists = _jax_logprobs_to_lists(logprobs,
+                                                    logits_indices_selector)
         else:
             logprobs_lists = None

@@ -1280,16 +1329,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         mrope_positions = self.mrope_positions_cpu[:, :
                                                    padded_total_num_scheduled_tokens]

-        block_tables = self.block_table_cpu[:self.max_num_reqs]
-        for dp_rank in range(dp_size):
-            req_offset = dp_rank * max_num_reqs_per_dp_rank
-            _num_reqs = num_req_per_dp_rank[dp_rank]
-
-            block_tables[
-                req_offset:req_offset + _num_reqs, :self.
-                max_num_blocks_per_req] = self.input_batch.block_table[
-                    0].get_cpu_tensor()[req_indices_dp[dp_rank]]
-
         query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs +
                                                    dp_size]
         seq_lens = self.seq_lens_cpu[:self.max_num_reqs]

@@ -1331,20 +1370,59 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         if self.uses_mrope:
             positions = mrope_positions

-        # Convert block_tables to 1D on cpu.
-        block_tables = block_tables.reshape(-1)
-
         query_start_loc_cpu = query_start_loc
         logits_indices_cpu = logits_indices
         seq_lens_cpu = seq_lens

-        (input_ids, positions, block_tables, query_start_loc, seq_lens,
-         logits_indices, request_distribution) = device_array(
+        (input_ids, positions, query_start_loc, seq_lens, logits_indices,
+         request_distribution) = device_array(
             self.mesh,
-            (input_ids, positions, block_tables, query_start_loc, seq_lens,
-             logits_indices, request_distribution),
+            (input_ids, positions, query_start_loc, seq_lens, logits_indices,
+             request_distribution),
             sharding=data_parallel_attn_sharding,
         )
+
+        attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+        uniform_attention_metadata: AttentionMetadata = None
+        for kv_cache_gid, kv_cache_group in enumerate(
+                self.kv_cache_config.kv_cache_groups):
+            block_tables = self.block_tables_cpu[kv_cache_gid][:self.
+                                                               max_num_reqs]
+            for dp_rank in range(dp_size):
+                req_offset = dp_rank * max_num_reqs_per_dp_rank
+                _num_reqs = num_req_per_dp_rank[dp_rank]
+
+                block_tables[
+                    req_offset:req_offset + _num_reqs, :self.
+                    max_num_blocks_per_req] = self.input_batch.block_table[
+                        0].get_cpu_tensor()[req_indices_dp[dp_rank]]
+            # Convert block_tables to 1D on cpu.
+            block_tables = block_tables.reshape(-1)
+            block_tables = device_array(
+                self.mesh,
+                (block_tables),
+                sharding=data_parallel_attn_sharding,
+            )
+
+            attention_metadata_gid = AttentionMetadata(
+                input_positions=positions,
+                block_tables=block_tables,
+                seq_lens=seq_lens,
+                query_start_loc=query_start_loc,
+                request_distribution=request_distribution,
+            )
+
+            # This is for making these cpu buffers hidden during tracing
+            attention_metadata_gid.query_start_loc_cpu = query_start_loc_cpu
+            attention_metadata_gid.seq_lens_cpu = seq_lens_cpu
+
+            if not self.use_hybrid_kvcache:
+                uniform_attention_metadata = attention_metadata_gid
+            else:
+                for layer_name in kv_cache_group.layer_names:
+                    attention_metadata_per_layer[
+                        layer_name] = attention_metadata_gid
+
         # Async scheduling: substitute placeholder tokens for DP
         if self.scheduler_config.async_scheduling and self._pre_async_results is not None:
             # Collect all token indices that need substitution across all DP ranks
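Note: with a hybrid KV cache (more than one cache group, e.g. full-attention plus sliding-window layers), each group now builds its own CPU block table and its own `AttentionMetadata`, keyed by layer name; with a single group, every layer shares one metadata object. Condensed sketch of the dispatch shape (types simplified):

    # Condensed sketch of the per-group metadata dispatch added here.
    from typing import Dict, Optional

    def build_attn_metadata(kv_cache_groups, build_for_group, use_hybrid: bool):
        per_layer: Dict[str, object] = {}
        uniform: Optional[object] = None
        for gid, group in enumerate(kv_cache_groups):
            md = build_for_group(gid)           # block tables etc. for group gid
            if not use_hybrid:
                uniform = md                    # single group: all layers share it
            else:
                for layer_name in group.layer_names:
                    per_layer[layer_name] = md  # hybrid: layer-name -> metadata
        return per_layer if use_hybrid else uniform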
@@ -1373,25 +1451,19 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 padded_total_num_scheduled_tokens,
             )

-        attention_metadata = AttentionMetadata(
-            input_positions=positions,
-            block_tables=block_tables,
-            seq_lens=seq_lens,
-            query_start_loc=query_start_loc,
-            request_distribution=request_distribution,
-        )
-
-        # This is for making these cpu buffers hidden during tracing
-        attention_metadata.query_start_loc_cpu = query_start_loc_cpu
-        attention_metadata.seq_lens_cpu = seq_lens_cpu
-
+        if self.use_hybrid_kvcache:
+            attention_metadata = attention_metadata_per_layer
+        else:
+            attention_metadata = uniform_attention_metadata
         return (
             input_ids,
+            positions,
             attention_metadata,
             sampling_metadata,
             logits_indices,
             spec_decode_metadata,
             logits_indices_selector,
+            padded_num_reqs,
         )

     def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"):

@@ -1492,9 +1564,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         positions = self.positions_cpu[:padded_total_num_scheduled_tokens]
         mrope_positions = self.mrope_positions_cpu[:, :
                                                    padded_total_num_scheduled_tokens]
-        block_tables = self.block_table_cpu[:self.max_num_reqs]
-        block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
-            self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs])

         # TODO(pooyam): Some paddings are up to `num_reqs_paddings` (spec decoding, select hidden states, etc) and some other are to `max_num_reqs` (block table, seq_lens). We should stick to one of them maybe?
         query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1]

@@ -1523,16 +1592,44 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             self.mesh, self.input_batch, padded_num_reqs)
         if self.uses_mrope:
             positions = mrope_positions
-
-        # Convert block_tables to 1D on cpu.
-        block_tables = block_tables.reshape(-1)
-
         query_start_loc_cpu = query_start_loc
         seq_lens_cpu = seq_lens
-        (input_ids, positions, block_tables, query_start_loc, seq_lens,
+
+        (input_ids, positions, query_start_loc, seq_lens,
         logits_indices, request_distribution) = device_array(
-            self.mesh, (input_ids, positions, block_tables,
-                        query_start_loc, seq_lens, logits_indices, request_distribution))
+            self.mesh, (input_ids, positions, query_start_loc, seq_lens,
+                        logits_indices, request_distribution))
+
+        attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+        uniform_attention_metadata: AttentionMetadata = None
+        for kv_cache_gid, kv_cache_group in enumerate(
+                self.kv_cache_config.kv_cache_groups):
+            block_tables = self.block_tables_cpu[kv_cache_gid][:self.
+                                                               max_num_reqs]
+            block_tables[:num_reqs] = (
+                self.input_batch.block_table[kv_cache_gid].get_cpu_tensor()
+                [:num_reqs])
+            # Convert block_tables to 1D on cpu.
+            block_tables = block_tables.reshape(-1)
+            block_tables = device_array(self.mesh, (block_tables))
+
+            attention_metadata_gid = AttentionMetadata(
+                input_positions=positions,
+                block_tables=block_tables,
+                seq_lens=seq_lens,
+                query_start_loc=query_start_loc,
+                request_distribution=request_distribution)
+            # This is for making these cpu buffers hidden during tracing
+            attention_metadata_gid.query_start_loc_cpu = query_start_loc_cpu
+            attention_metadata_gid.seq_lens_cpu = seq_lens_cpu
+
+            if not self.use_hybrid_kvcache:
+                # all layers share the same attention metadata
+                uniform_attention_metadata = attention_metadata_gid
+            else:
+                for layer_name in kv_cache_group.layer_names:
+                    attention_metadata_per_layer[
+                        layer_name] = attention_metadata_gid

         if self.scheduler_config.async_scheduling and len(
                 token_in_tpu_cur_input_indices) > 0:

@@ -1545,20 +1642,15 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             self.lora_utils.set_active_loras(
                 num_scheduled_tokens_per_req, total_num_scheduled_tokens,
                 padded_total_num_scheduled_tokens)
-
-        attention_metadata = AttentionMetadata(
-            input_positions=positions,
-            block_tables=block_tables,
-            seq_lens=seq_lens,
-            query_start_loc=query_start_loc,
-            request_distribution=request_distribution)
-
-        # This is for making these cpu buffers hidden during tracing
-        attention_metadata.query_start_loc_cpu = query_start_loc_cpu
-        attention_metadata.seq_lens_cpu = seq_lens_cpu
         logits_indices_selector = None
-        return (input_ids, attention_metadata, sampling_metadata,
-                logits_indices, spec_decode_metadata, logits_indices_selector)
+
+        if self.use_hybrid_kvcache:
+            attention_metadata = attention_metadata_per_layer
+        else:
+            attention_metadata = uniform_attention_metadata
+        return (input_ids, positions, attention_metadata, sampling_metadata,
+                logits_indices, spec_decode_metadata, logits_indices_selector,
+                padded_num_reqs)

     def _get_input_ids_embeds(self, input_ids: jax.Array,
                               mm_embeds: list[jax.Array]):

@@ -1618,3 +1710,35 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             mappings=mappings,
             transpose_keys=transpose_keys,
             shard=shard)
+
+    def get_intermediate_tensor_spec(self, num_tokens: int):
+        impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
+        jax_dtype = t2j_dtype(self.dtype) if impl == "vllm" else self.dtype
+        num_padded_tokens = runner_utils.get_padded_token_len(
+            self.num_tokens_paddings, num_tokens)
+        sharding = NamedSharding(self.mesh, PartitionSpec())
+        hidden_size = self.model_config.get_hidden_size()
+        spec = jax.ShapeDtypeStruct(shape=(num_padded_tokens, hidden_size),
+                                    dtype=jax_dtype,
+                                    sharding=sharding)
+        tensor_spec = {"hidden_states": spec, "residual": spec}
+        return tensor_spec
+
+    def get_uuid_for_jax_transfer(self,
+                                  scheduler_output: "VllmSchedulerOutput",
+                                  rank: int, step: int) -> int:
+        '''
+        Get a uuid for jax.transfer, here we use the hash of
+        scheduler_output + counter_step + sender's rank
+        '''
+        scheduler_output_str = ""
+        if not scheduler_output.num_scheduled_tokens:
+            scheduler_output_str = "empty_batch"
+        else:
+            scheduler_output_str = str(
+                sorted(scheduler_output.num_scheduled_tokens.items()))
+        unique_str = f'{scheduler_output_str} {step} {rank}'
+        import hashlib
+        hasher = hashlib.sha1()
+        hasher.update(unique_str.encode('utf-8'))
+        return int.from_bytes(hasher.digest()[:8], 'big')
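Note: the sender and receiver of a stage-to-stage transfer must agree on an identifier without extra communication, so the uuid is derived deterministically from data both sides already hold: the scheduled-token map, the step counter, and the sender's rank, hashed down to a stable 64-bit integer. Standalone illustration of the same derivation:

    # Standalone illustration: both ends hash identical inputs, so they
    # agree on the transfer id without communicating.
    import hashlib

    def transfer_uuid(num_scheduled_tokens: dict, rank: int, step: int) -> int:
        desc = "empty_batch" if not num_scheduled_tokens else str(
            sorted(num_scheduled_tokens.items()))
        digest = hashlib.sha1(f"{desc} {step} {rank}".encode("utf-8")).digest()
        return int.from_bytes(digest[:8], "big")

    assert transfer_uuid({"req-a": 17}, rank=0, step=3) == \
           transfer_uuid({"req-a": 17}, rank=0, step=3)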