tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +317 -34
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +26 -6
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +25 -4
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +25 -12
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +32 -9
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +101 -494
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
- tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +112 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +18 -5
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +179 -51
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
- tpu_inference/models/jax/utils/weight_utils.py +234 -155
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +51 -72
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +180 -80
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +55 -33
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +16 -3
- tpu_inference/runner/tpu_runner.py +124 -61
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +84 -22
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +66 -52
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -186
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,20 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
import copy
|
|
2
16
|
import functools
|
|
3
|
-
import
|
|
17
|
+
import logging
|
|
4
18
|
import random
|
|
5
19
|
from contextlib import nullcontext
|
|
6
20
|
from dataclasses import dataclass
|
|
@@ -10,17 +24,15 @@ import jax
|
|
|
10
24
|
import jax.numpy as jnp
|
|
11
25
|
import jaxtyping
|
|
12
26
|
import numpy as np
|
|
13
|
-
import torch
|
|
14
|
-
import vllm.envs as envs
|
|
27
|
+
import vllm.envs as vllm_envs
|
|
15
28
|
from flax import nnx
|
|
16
29
|
from jax.experimental import mesh_utils
|
|
17
30
|
from jax.sharding import NamedSharding, PartitionSpec
|
|
18
|
-
from torchax.ops.mappings import j2t_dtype
|
|
19
31
|
from vllm.config import VllmConfig
|
|
32
|
+
from vllm.distributed import get_pp_group
|
|
20
33
|
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
|
21
34
|
has_kv_transfer_group)
|
|
22
35
|
from vllm.forward_context import set_forward_context
|
|
23
|
-
from vllm.sequence import IntermediateTensors
|
|
24
36
|
from vllm.tasks import SupportedTask
|
|
25
37
|
from vllm.utils.math_utils import cdiv
|
|
26
38
|
from vllm.v1.core.sched.output import GrammarOutput
|
|
@@ -35,6 +47,7 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import \
|
|
|
35
47
|
KVConnectorModelRunnerMixin
|
|
36
48
|
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
|
37
49
|
|
|
50
|
+
import tpu_inference.envs as envs
|
|
38
51
|
from tpu_inference import utils as common_utils
|
|
39
52
|
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
|
|
40
53
|
from tpu_inference.layers.common.sharding import (MESH_AXIS_NAMES,
|
|
@@ -48,6 +61,8 @@ from tpu_inference.layers.jax.sample.sampling_metadata import \
|
|
|
48
61
|
TPUSupportedSamplingMetadata
|
|
49
62
|
from tpu_inference.logger import init_logger
|
|
50
63
|
from tpu_inference.models.common.model_loader import get_model
|
|
64
|
+
from tpu_inference.models.jax.jax_intermediate_tensor import \
|
|
65
|
+
JaxIntermediateTensors
|
|
51
66
|
from tpu_inference.models.jax.utils.weight_utils import (
|
|
52
67
|
shard_put, transfer_state_with_mappings)
|
|
53
68
|
from tpu_inference.runner import utils as runner_utils
|
|
@@ -64,10 +79,12 @@ from tpu_inference.runner.structured_decoding_manager import \
|
|
|
64
79
|
StructuredDecodingManager
|
|
65
80
|
from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
|
|
66
81
|
from tpu_inference.utils import (device_array, make_optimized_mesh,
|
|
67
|
-
time_function)
|
|
82
|
+
time_function, to_jax_dtype, to_torch_dtype)
|
|
68
83
|
|
|
69
84
|
logger = init_logger(__name__)
|
|
70
85
|
|
|
86
|
+
logging.getLogger("torchax.tensor").setLevel(logging.ERROR)
|
|
87
|
+
|
|
71
88
|
INVALID_TOKEN_ID = -1
|
|
72
89
|
# Smallest output size
|
|
73
90
|
MIN_NUM_SEQS = 8
|
|
@@ -78,17 +95,6 @@ DUMMY_METADATA = AttentionMetadata(
|
|
|
78
95
|
request_distribution=[0, 0, 0],
|
|
79
96
|
)
|
|
80
97
|
|
|
81
|
-
TPU_STR_DTYPE_TO_TORCH_DTYPE = {
|
|
82
|
-
"half": torch.half,
|
|
83
|
-
"bfloat16": torch.bfloat16,
|
|
84
|
-
"float": torch.float,
|
|
85
|
-
"fp8": torch.float8_e4m3fn,
|
|
86
|
-
"fp8_e4m3": torch.float8_e4m3fn,
|
|
87
|
-
"fp8_e5m2": torch.float8_e5m2,
|
|
88
|
-
"int8": torch.int8,
|
|
89
|
-
"uint8": torch.uint8,
|
|
90
|
-
}
|
|
91
|
-
|
|
92
98
|
|
|
93
99
|
class AsyncTPUModelRunnerOutput(AsyncModelRunnerOutput):
|
|
94
100
|
"""Holds asynchronous model output specifically from a TPU runner.
|
|
@@ -243,6 +249,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
243
249
|
self.maybe_forbid_compile = runner_utils.ForbidCompile(
|
|
244
250
|
) if envs.VLLM_XLA_CHECK_RECOMPILATION else nullcontext()
|
|
245
251
|
self.dp_size = self.vllm_config.sharding_config.total_dp_size
|
|
252
|
+
self.rank = rank
|
|
253
|
+
self.is_first_rank = is_first_rank
|
|
254
|
+
self.is_last_rank = is_last_rank
|
|
246
255
|
|
|
247
256
|
self._init_random()
|
|
248
257
|
self._init_mesh()
|
|
@@ -253,36 +262,29 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
253
262
|
|
|
254
263
|
# Delegate functions to specific manager classes.
|
|
255
264
|
self.compilation_manager = CompilationManager(self)
|
|
256
|
-
self.
|
|
257
|
-
|
|
265
|
+
if self.is_last_rank:
|
|
266
|
+
self.speculative_decoding_manager = SpeculativeDecodingManager(
|
|
267
|
+
self)
|
|
268
|
+
self.structured_decoding_manager = StructuredDecodingManager(self)
|
|
258
269
|
self.kv_cache_manager = KVCacheManager(self)
|
|
259
270
|
self.mm_manager = MultiModalManager(self)
|
|
260
271
|
self.persistent_batch_manager = PersistentBatchManager(
|
|
261
272
|
self.requests, self.input_batch, self.encoder_cache,
|
|
262
|
-
self.uses_mrope, self.model_config)
|
|
273
|
+
self.uses_mrope, self.model_config, self.is_last_rank)
|
|
263
274
|
self.lora_utils = LoraUtils(self)
|
|
264
275
|
|
|
265
|
-
|
|
266
|
-
if
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
|
|
270
|
-
elif isinstance(getattr(model_dtype, 'dtype', None), jnp.dtype):
|
|
271
|
-
self.kv_cache_dtype = j2t_dtype(model_dtype.dtype)
|
|
272
|
-
elif isinstance(model_dtype, torch.dtype):
|
|
273
|
-
self.kv_cache_dtype = model_dtype
|
|
274
|
-
else:
|
|
275
|
-
raise ValueError(
|
|
276
|
-
"KV cache is unsupported for model_dtype of %s",
|
|
277
|
-
model_dtype)
|
|
278
|
-
else:
|
|
279
|
-
self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[
|
|
280
|
-
cache_config.cache_dtype]
|
|
276
|
+
cache_dtype = self.cache_config.cache_dtype
|
|
277
|
+
if cache_dtype == "auto":
|
|
278
|
+
cache_dtype = self.dtype
|
|
279
|
+
self.kv_cache_dtype = to_torch_dtype(cache_dtype)
|
|
281
280
|
|
|
282
281
|
self._pre_async_results: AsyncPreResults | None = None
|
|
283
282
|
self._substitute_placeholder_token_fn = _substitute_placeholder_token
|
|
284
283
|
self.execute_model_state: ExecuteModelState | None = None
|
|
285
284
|
|
|
285
|
+
self.kv_caches: list[jax.Array] = []
|
|
286
|
+
self.layer_name_to_kvcache_index: dict[str, int] = {}
|
|
287
|
+
|
|
286
288
|
def _init_random(self):
|
|
287
289
|
if self.model_config.seed is None:
|
|
288
290
|
self.model_config.seed = 0
|
|
@@ -291,7 +293,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
291
293
|
self.rng_key = jax.random.key(self.model_config.seed)
|
|
292
294
|
|
|
293
295
|
def _init_mesh(self) -> None:
|
|
294
|
-
if
|
|
296
|
+
if envs.NEW_MODEL_DESIGN:
|
|
295
297
|
self.mesh = self._create_new_model_mesh()
|
|
296
298
|
else:
|
|
297
299
|
# NOTE(wenxindongwork): The new MoE kernel expects a 2D mesh, so we need
|
|
@@ -302,7 +304,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
302
304
|
logger.info(f"Init mesh | mesh={self.mesh}")
|
|
303
305
|
|
|
304
306
|
def _create_new_model_mesh(self) -> jax.sharding.Mesh:
|
|
305
|
-
num_slices =
|
|
307
|
+
num_slices = envs.NUM_SLICES
|
|
306
308
|
|
|
307
309
|
logger.info(f"Creating new model mesh | devices={len(self.devices)}, "
|
|
308
310
|
f"num_slices={num_slices}")
|
|
@@ -371,7 +373,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
371
373
|
devices=self.devices)
|
|
372
374
|
|
|
373
375
|
def _init_phased_profiling(self) -> None:
|
|
374
|
-
self.phased_profiling_dir =
|
|
376
|
+
self.phased_profiling_dir = envs.PHASED_PROFILING_DIR
|
|
375
377
|
self.phase_based_profiler = None
|
|
376
378
|
if self.phased_profiling_dir:
|
|
377
379
|
self.phase_based_profiler = runner_utils.PhasedBasedProfiler(
|
|
@@ -413,7 +415,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
413
415
|
min_token_size=max(16, self.dp_size),
|
|
414
416
|
max_token_size=scheduler_config.max_num_batched_tokens *
|
|
415
417
|
self.dp_size,
|
|
416
|
-
padding_gap=
|
|
418
|
+
padding_gap=vllm_envs.VLLM_TPU_BUCKET_PADDING_GAP)
|
|
417
419
|
self.num_tokens_paddings_per_dp = [
|
|
418
420
|
padding // self.dp_size for padding in self.num_tokens_paddings
|
|
419
421
|
]
|
|
@@ -509,10 +511,10 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
509
511
|
multimodal_fns = multimodal_fns or {}
|
|
510
512
|
self.precompile_vision_encoder_fn = multimodal_fns.get(
|
|
511
513
|
"precompile_vision_encoder_fn", None)
|
|
512
|
-
self.
|
|
513
|
-
|
|
514
|
-
self.
|
|
515
|
-
|
|
514
|
+
self.embed_multimodal_fn = multimodal_fns.get("embed_multimodal_fn",
|
|
515
|
+
None)
|
|
516
|
+
self.embed_input_ids_fn = multimodal_fns.get("embed_input_ids_fn",
|
|
517
|
+
None)
|
|
516
518
|
self.get_mrope_input_positions_fn = multimodal_fns.get(
|
|
517
519
|
"get_mrope_input_positions_fn", None)
|
|
518
520
|
|
|
@@ -524,7 +526,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
524
526
|
jax.random.key(self.model_config.seed)).params()
|
|
525
527
|
self.is_multimodal_model = (
|
|
526
528
|
self.model_config.is_multimodal_model
|
|
527
|
-
and self.
|
|
529
|
+
and self.embed_multimodal_fn is not None and hasattr(
|
|
528
530
|
self.model_config.hf_config, "architectures"
|
|
529
531
|
) #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
|
|
530
532
|
and len(self.model_config.hf_config.architectures) >= 1
|
|
@@ -540,10 +542,12 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
540
542
|
def get_kv_cache_spec(self):
|
|
541
543
|
return self.kv_cache_manager.get_kv_cache_spec()
|
|
542
544
|
|
|
543
|
-
def initialize_kv_cache(self,
|
|
545
|
+
def initialize_kv_cache(self,
|
|
546
|
+
kv_cache_config: KVCacheConfig,
|
|
547
|
+
topology_order_id: int = 0) -> None:
|
|
548
|
+
self.topology_order_id = topology_order_id
|
|
544
549
|
self.kv_cache_config = kv_cache_config
|
|
545
550
|
self.use_hybrid_kvcache = len(kv_cache_config.kv_cache_groups) > 1
|
|
546
|
-
self.kv_caches = []
|
|
547
551
|
self.kv_cache_manager.initialize_kv_cache(kv_cache_config)
|
|
548
552
|
if has_kv_transfer_group():
|
|
549
553
|
get_kv_transfer_group().register_runner(self)
|
|
@@ -555,12 +559,12 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
555
559
|
def execute_model(
|
|
556
560
|
self,
|
|
557
561
|
scheduler_output: "VllmSchedulerOutput",
|
|
558
|
-
intermediate_tensors: Optional[
|
|
559
|
-
) -> ModelRunnerOutput | None:
|
|
562
|
+
intermediate_tensors: Optional[JaxIntermediateTensors] = None,
|
|
563
|
+
) -> ModelRunnerOutput | JaxIntermediateTensors | None:
|
|
560
564
|
if self.execute_model_state is not None:
|
|
561
565
|
raise RuntimeError("State error: sample_tokens() must be called "
|
|
562
566
|
"after execute_model() returns None.")
|
|
563
|
-
_, output = self._execute_model(scheduler_output)
|
|
567
|
+
_, output = self._execute_model(scheduler_output, intermediate_tensors)
|
|
564
568
|
return output
|
|
565
569
|
|
|
566
570
|
def sample_tokens(
|
|
@@ -686,7 +690,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
686
690
|
def _execute_model(
|
|
687
691
|
self,
|
|
688
692
|
scheduler_output: "VllmSchedulerOutput",
|
|
689
|
-
|
|
693
|
+
intermediate_tensors: Optional[JaxIntermediateTensors] = None,
|
|
694
|
+
) -> tuple[AttentionMetadata, JaxIntermediateTensors | ModelRunnerOutput
|
|
695
|
+
| None]:
|
|
690
696
|
self.persistent_batch_manager.update_states(
|
|
691
697
|
scheduler_output, self.get_mrope_input_positions_fn)
|
|
692
698
|
if not scheduler_output.total_num_scheduled_tokens:
|
|
@@ -764,7 +770,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
764
770
|
scheduler_output) as kv_connector_output:
|
|
765
771
|
# NOTE(Wenlong): It takes both `input_ids` and `inputs_embeds`,
|
|
766
772
|
# but one of them would be `None`
|
|
767
|
-
|
|
768
773
|
(self.kv_caches, hidden_states,
|
|
769
774
|
aux_hidden_states) = self.model_fn(
|
|
770
775
|
self.state,
|
|
@@ -775,8 +780,14 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
775
780
|
input_positions,
|
|
776
781
|
tuple(self.layer_name_to_kvcache_index.items()),
|
|
777
782
|
lora_metadata,
|
|
783
|
+
intermediate_tensors,
|
|
784
|
+
self.is_first_rank,
|
|
785
|
+
self.is_last_rank,
|
|
778
786
|
)
|
|
779
|
-
|
|
787
|
+
if not get_pp_group().is_last_rank:
|
|
788
|
+
assert isinstance(hidden_states, JaxIntermediateTensors)
|
|
789
|
+
hidden_states.kv_connector_output = kv_connector_output
|
|
790
|
+
return attn_metadata, hidden_states
|
|
780
791
|
hidden_states = self._select_from_array_fn(hidden_states,
|
|
781
792
|
logits_indices)
|
|
782
793
|
logits = self.compute_logits_fn(
|
|
@@ -818,22 +829,35 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
818
829
|
sharding = None
|
|
819
830
|
if self.dp_size > 1:
|
|
820
831
|
sharding = NamedSharding(self.mesh,
|
|
821
|
-
PartitionSpec(ShardingAxisName.
|
|
832
|
+
PartitionSpec(ShardingAxisName.MLP_DATA))
|
|
822
833
|
|
|
823
834
|
tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
|
|
824
835
|
self.mesh, self.input_batch, padded_num_reqs, sharding=sharding)
|
|
836
|
+
|
|
837
|
+
# TODO(pooyam): Should we move this to `_prepare_inputs`?
|
|
838
|
+
if tpu_sampling_metadata.do_sampling:
|
|
839
|
+
self.rng_params_for_sampling, step_rng = jax.random.split(
|
|
840
|
+
self.rng_params_for_sampling)
|
|
841
|
+
else:
|
|
842
|
+
step_rng = self.rng_params_for_sampling
|
|
843
|
+
|
|
825
844
|
if spec_decode_metadata is None:
|
|
826
845
|
next_tokens = sample(
|
|
827
|
-
|
|
846
|
+
step_rng,
|
|
828
847
|
self.mesh,
|
|
829
848
|
logits,
|
|
830
849
|
tpu_sampling_metadata,
|
|
831
850
|
)
|
|
832
851
|
else:
|
|
852
|
+
if tpu_sampling_metadata.do_sampling:
|
|
853
|
+
bonus_rng, rejection_rng = jax.random.split(step_rng)
|
|
854
|
+
else:
|
|
855
|
+
bonus_rng = step_rng
|
|
856
|
+
rejection_rng = step_rng
|
|
833
857
|
bonus_logits = self._select_from_array_fn(
|
|
834
858
|
logits, spec_decode_metadata.bonus_logits_indices)
|
|
835
859
|
bonus_token_ids = sample(
|
|
836
|
-
|
|
860
|
+
bonus_rng,
|
|
837
861
|
self.mesh,
|
|
838
862
|
bonus_logits,
|
|
839
863
|
tpu_sampling_metadata,
|
|
@@ -847,7 +871,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
847
871
|
target_logits=target_logits,
|
|
848
872
|
bonus_token_ids=bonus_token_ids,
|
|
849
873
|
sampling_metadata=tpu_sampling_metadata,
|
|
850
|
-
key=
|
|
874
|
+
key=rejection_rng,
|
|
851
875
|
)
|
|
852
876
|
|
|
853
877
|
if tpu_sampling_metadata.logprobs:
|
|
@@ -1332,7 +1356,14 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
1332
1356
|
_request_distribution = []
|
|
1333
1357
|
for dp_rank in range(dp_size):
|
|
1334
1358
|
_num_reqs = num_req_per_dp_rank[dp_rank]
|
|
1335
|
-
|
|
1359
|
+
# The batch has been reordered by _reorder_batch so decode requests come first
|
|
1360
|
+
# Count decode requests (those with num_scheduled_tokens == 1) in this DP rank
|
|
1361
|
+
num_decode_in_dp_rank = 0
|
|
1362
|
+
for req_id in req_ids_dp[dp_rank]:
|
|
1363
|
+
if scheduler_output.num_scheduled_tokens[req_id] == 1:
|
|
1364
|
+
num_decode_in_dp_rank += 1
|
|
1365
|
+
_request_distribution.append(
|
|
1366
|
+
[num_decode_in_dp_rank, num_decode_in_dp_rank, _num_reqs])
|
|
1336
1367
|
request_distribution = np.array(_request_distribution).ravel()
|
|
1337
1368
|
|
|
1338
1369
|
use_spec_decode = len(
|
|
@@ -1361,7 +1392,8 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
1361
1392
|
self.mesh,
|
|
1362
1393
|
self.input_batch,
|
|
1363
1394
|
padded_num_reqs,
|
|
1364
|
-
sharding=
|
|
1395
|
+
sharding=NamedSharding(self.mesh,
|
|
1396
|
+
PartitionSpec(ShardingAxisName.MLP_DATA)),
|
|
1365
1397
|
)
|
|
1366
1398
|
if self.uses_mrope:
|
|
1367
1399
|
positions = mrope_positions
|
|
@@ -1391,7 +1423,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
1391
1423
|
block_tables[
|
|
1392
1424
|
req_offset:req_offset + _num_reqs, :self.
|
|
1393
1425
|
max_num_blocks_per_req] = self.input_batch.block_table[
|
|
1394
|
-
|
|
1426
|
+
kv_cache_gid].get_cpu_tensor()[req_indices_dp[dp_rank]]
|
|
1395
1427
|
# Convert block_tables to 1D on cpu.
|
|
1396
1428
|
block_tables = block_tables.reshape(-1)
|
|
1397
1429
|
block_tables = device_array(
|
|
@@ -1651,7 +1683,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
1651
1683
|
def _get_input_ids_embeds(self, input_ids: jax.Array,
|
|
1652
1684
|
mm_embeds: list[jax.Array]):
|
|
1653
1685
|
if self.is_multimodal_model:
|
|
1654
|
-
inputs_embeds = self.
|
|
1686
|
+
inputs_embeds = self.embed_input_ids_fn(
|
|
1655
1687
|
self.state,
|
|
1656
1688
|
input_ids,
|
|
1657
1689
|
mm_embeds,
|
|
@@ -1706,3 +1738,34 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
1706
1738
|
mappings=mappings,
|
|
1707
1739
|
transpose_keys=transpose_keys,
|
|
1708
1740
|
shard=shard)
|
|
1741
|
+
|
|
1742
|
+
def get_intermediate_tensor_spec(self, num_tokens: int):
|
|
1743
|
+
jax_dtype = to_jax_dtype(self.dtype)
|
|
1744
|
+
num_padded_tokens = runner_utils.get_padded_token_len(
|
|
1745
|
+
self.num_tokens_paddings, num_tokens)
|
|
1746
|
+
sharding = NamedSharding(self.mesh, PartitionSpec())
|
|
1747
|
+
hidden_size = self.model_config.get_hidden_size()
|
|
1748
|
+
spec = jax.ShapeDtypeStruct(shape=(num_padded_tokens, hidden_size),
|
|
1749
|
+
dtype=jax_dtype,
|
|
1750
|
+
sharding=sharding)
|
|
1751
|
+
tensor_spec = {"hidden_states": spec, "residual": spec}
|
|
1752
|
+
return tensor_spec
|
|
1753
|
+
|
|
1754
|
+
def get_uuid_for_jax_transfer(self,
|
|
1755
|
+
scheduler_output: "VllmSchedulerOutput",
|
|
1756
|
+
rank: int, step: int) -> int:
|
|
1757
|
+
'''
|
|
1758
|
+
Get a uuid for jax.transfer, here we use the hash of
|
|
1759
|
+
scheduler_output + counter_step + sender's rank
|
|
1760
|
+
'''
|
|
1761
|
+
scheduler_output_str = ""
|
|
1762
|
+
if not scheduler_output.num_scheduled_tokens:
|
|
1763
|
+
scheduler_output_str = "empty_batch"
|
|
1764
|
+
else:
|
|
1765
|
+
scheduler_output_str = str(
|
|
1766
|
+
sorted(scheduler_output.num_scheduled_tokens.items()))
|
|
1767
|
+
unique_str = f'{scheduler_output_str} {step} {rank}'
|
|
1768
|
+
import hashlib
|
|
1769
|
+
hasher = hashlib.sha1()
|
|
1770
|
+
hasher.update(unique_str.encode('utf-8'))
|
|
1771
|
+
return int.from_bytes(hasher.digest()[:8], 'big')
|
tpu_inference/runner/utils.py
CHANGED
|
@@ -15,6 +15,7 @@ import jax
|
|
|
15
15
|
from jax._src.interpreters import pxla
|
|
16
16
|
from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
|
|
17
17
|
|
|
18
|
+
from tpu_inference import envs
|
|
18
19
|
from tpu_inference.logger import init_logger
|
|
19
20
|
from tpu_inference.runner.input_batch import InputBatch
|
|
20
21
|
|
|
@@ -306,8 +307,7 @@ class PhasedBasedProfiler:
|
|
|
306
307
|
InferencePhase.BALANCED: False
|
|
307
308
|
}
|
|
308
309
|
self.default_profiling_options = jax.profiler.ProfileOptions()
|
|
309
|
-
self.default_profiling_options.python_tracer_level =
|
|
310
|
-
"PYTHON_TRACER_LEVEL", 0)
|
|
310
|
+
self.default_profiling_options.python_tracer_level = envs.PYTHON_TRACER_LEVEL
|
|
311
311
|
|
|
312
312
|
self.current_phase: str = ""
|
|
313
313
|
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -1,3 +1,16 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
1
14
|
"""Implements the Eagle3 proposer for speculative decoding on JAX/TPU."""
|
|
2
15
|
import functools
|
|
3
16
|
from dataclasses import replace
|
|
@@ -6,13 +19,19 @@ from typing import Any, Optional
|
|
|
6
19
|
import jax
|
|
7
20
|
import jax.numpy as jnp
|
|
8
21
|
import numpy as np
|
|
22
|
+
from flax import nnx
|
|
23
|
+
from jax import lax
|
|
24
|
+
from jax.sharding import NamedSharding, PartitionSpec
|
|
9
25
|
from vllm.config import VllmConfig
|
|
10
26
|
|
|
11
27
|
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
|
|
28
|
+
from tpu_inference.logger import init_logger
|
|
12
29
|
from tpu_inference.models.common.model_loader import get_model
|
|
13
30
|
from tpu_inference.runner import utils as runner_utils
|
|
14
31
|
from tpu_inference.utils import device_array
|
|
15
32
|
|
|
33
|
+
logger = init_logger(__name__)
|
|
34
|
+
|
|
16
35
|
|
|
17
36
|
class Eagle3Proposer:
|
|
18
37
|
"""A proposer for speculative decoding using the Eagle3 method.
|
|
@@ -51,9 +70,22 @@ class Eagle3Proposer:
|
|
|
51
70
|
"""Loads the draft model."""
|
|
52
71
|
self.model_fn, self.compute_logits_fn, self.combine_hidden_states_fn, _, self.state, _, _ = get_model(
|
|
53
72
|
self.vllm_config, self.rng_key, self.mesh, is_draft_model=True)
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
73
|
+
|
|
74
|
+
draft_embed_tokens = getattr(self.state.model, 'embed_tokens', None)
|
|
75
|
+
if draft_embed_tokens is None or ~jnp.any(
|
|
76
|
+
draft_embed_tokens.embedding):
|
|
77
|
+
logger.info(
|
|
78
|
+
"Draft model does not have embedding. Setting draft model's embed_tokens to target model's embed"
|
|
79
|
+
)
|
|
80
|
+
self.state.model.embed_tokens = target_model.model.embed
|
|
81
|
+
elif jnp.array_equal(draft_embed_tokens.embedding,
|
|
82
|
+
target_model.model.embed.embedding):
|
|
83
|
+
logger.info(
|
|
84
|
+
"Draft model's embed_tokens is identical to target model's embed. Sharing the embedding."
|
|
85
|
+
)
|
|
86
|
+
self.state.model.embed_tokens = target_model.model.embed
|
|
87
|
+
else:
|
|
88
|
+
logger.info("Draft model has its own embed_tokens.")
|
|
57
89
|
|
|
58
90
|
@functools.partial(jax.jit, static_argnums=(0, ))
|
|
59
91
|
def _prepare_input_ids(
|
|
@@ -111,6 +143,17 @@ class Eagle3Proposer:
|
|
|
111
143
|
max_num_blocks_per_req)
|
|
112
144
|
new_block_tables = jnp.where(expanded_exceeds_mask, -1, block_tables)
|
|
113
145
|
|
|
146
|
+
positions = lax.with_sharding_constraint(
|
|
147
|
+
positions, NamedSharding(self.mesh, PartitionSpec(None, )))
|
|
148
|
+
clamped_positions = lax.with_sharding_constraint(
|
|
149
|
+
clamped_positions, NamedSharding(self.mesh, PartitionSpec(None, )))
|
|
150
|
+
new_seq_lens = lax.with_sharding_constraint(
|
|
151
|
+
new_seq_lens, NamedSharding(self.mesh, PartitionSpec(None, )))
|
|
152
|
+
query_start_loc = lax.with_sharding_constraint(
|
|
153
|
+
query_start_loc, NamedSharding(self.mesh, PartitionSpec()))
|
|
154
|
+
new_block_tables = lax.with_sharding_constraint(
|
|
155
|
+
new_block_tables, NamedSharding(self.mesh, PartitionSpec(None, )))
|
|
156
|
+
|
|
114
157
|
return positions, clamped_positions, new_seq_lens, query_start_loc, new_block_tables
|
|
115
158
|
|
|
116
159
|
@functools.partial(jax.jit, static_argnums=(0, ))
|
|
@@ -122,6 +165,7 @@ class Eagle3Proposer:
|
|
|
122
165
|
@functools.partial(jax.jit, static_argnums=(0, ))
|
|
123
166
|
def _prepare_hidden_states_and_input_ids(
|
|
124
167
|
self,
|
|
168
|
+
state: nnx.State,
|
|
125
169
|
aux_hidden_states: tuple[jax.Array, ...],
|
|
126
170
|
query_start_loc: jax.Array,
|
|
127
171
|
target_token_ids: jax.Array,
|
|
@@ -130,7 +174,7 @@ class Eagle3Proposer:
|
|
|
130
174
|
) -> tuple[jax.Array, jax.Array, jax.Array]:
|
|
131
175
|
target_hidden_states = jnp.concatenate(aux_hidden_states, axis=-1)
|
|
132
176
|
target_hidden_states = self.combine_hidden_states_fn(
|
|
133
|
-
|
|
177
|
+
state, target_hidden_states)
|
|
134
178
|
|
|
135
179
|
input_ids, last_token_indices = self._prepare_input_ids(
|
|
136
180
|
query_start_loc, target_token_ids, next_token_ids, num_reqs)
|
|
@@ -177,8 +221,8 @@ class Eagle3Proposer:
|
|
|
177
221
|
block_tables=device_array(
|
|
178
222
|
self.mesh, block_tables))
|
|
179
223
|
target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
|
|
180
|
-
aux_hidden_states, attn_metadata.query_start_loc,
|
|
181
|
-
next_token_ids, num_reqs)
|
|
224
|
+
self.state, aux_hidden_states, attn_metadata.query_start_loc,
|
|
225
|
+
input_ids, next_token_ids, num_reqs)
|
|
182
226
|
return target_hidden_states, input_ids, last_token_indices, attn_metadata
|
|
183
227
|
|
|
184
228
|
# Host copies from the metadata prepared by the runner.
|
|
@@ -242,12 +286,13 @@ class Eagle3Proposer:
|
|
|
242
286
|
|
|
243
287
|
attn_metadata = replace(attn_metadata, block_tables=block_tables)
|
|
244
288
|
return self._filter_token_and_prepare_initial_inputs(
|
|
245
|
-
token_indices, query_start_loc, seq_lens, input_ids,
|
|
289
|
+
self.state, token_indices, query_start_loc, seq_lens, input_ids,
|
|
246
290
|
aux_hidden_states, attn_metadata, next_token_ids, num_reqs)
|
|
247
291
|
|
|
248
292
|
@functools.partial(jax.jit, static_argnums=(0, ))
|
|
249
293
|
def _filter_token_and_prepare_initial_inputs(
|
|
250
294
|
self,
|
|
295
|
+
state: nnx.State,
|
|
251
296
|
token_indices: jax.Array,
|
|
252
297
|
query_start_loc: jax.Array,
|
|
253
298
|
seq_lens: jax.Array,
|
|
@@ -275,35 +320,51 @@ class Eagle3Proposer:
|
|
|
275
320
|
)
|
|
276
321
|
|
|
277
322
|
target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
|
|
278
|
-
[h[token_indices] for h in aux_hidden_states],
|
|
279
|
-
target_token_ids, next_token_ids, num_reqs)
|
|
323
|
+
state, [h[token_indices] for h in aux_hidden_states],
|
|
324
|
+
query_start_loc, target_token_ids, next_token_ids, num_reqs)
|
|
280
325
|
|
|
281
326
|
return target_hidden_states, input_ids, last_token_indices, attn_metadata
|
|
282
327
|
|
|
283
328
|
@functools.partial(jax.jit, static_argnums=(0, ))
|
|
284
329
|
def _select_draft_token_ids(
|
|
285
330
|
self,
|
|
331
|
+
state: nnx.State,
|
|
286
332
|
hidden_states: jax.Array,
|
|
287
333
|
last_token_indices: jax.Array,
|
|
288
334
|
) -> jax.Array:
|
|
289
335
|
sample_hidden_states = hidden_states[last_token_indices]
|
|
290
|
-
|
|
336
|
+
sample_hidden_states = lax.with_sharding_constraint(
|
|
337
|
+
sample_hidden_states,
|
|
338
|
+
NamedSharding(self.mesh, PartitionSpec(None, None)))
|
|
339
|
+
return self._get_draft_token_ids(state, sample_hidden_states)
|
|
291
340
|
|
|
292
341
|
@functools.partial(jax.jit, static_argnums=(0, ))
|
|
293
|
-
def _get_draft_token_ids(self,
|
|
342
|
+
def _get_draft_token_ids(self, state: nnx.State,
|
|
343
|
+
hidden_states: jax.Array) -> jax.Array:
|
|
294
344
|
lora_metadata = None
|
|
295
|
-
logits = self.compute_logits_fn(
|
|
296
|
-
|
|
297
|
-
return
|
|
345
|
+
logits = self.compute_logits_fn(state, hidden_states, lora_metadata)
|
|
346
|
+
draft_token_ids = jnp.argmax(logits, axis=-1)
|
|
347
|
+
return lax.with_sharding_constraint(
|
|
348
|
+
draft_token_ids, NamedSharding(self.mesh, PartitionSpec()))
|
|
298
349
|
|
|
299
350
|
@functools.partial(jax.jit, static_argnums=(0, ))
|
|
300
351
|
def _select_inputs_for_loop_speculation(
|
|
301
|
-
self, positions: jax.Array, residual: jax.Array,
|
|
352
|
+
self, state: nnx.State, positions: jax.Array, residual: jax.Array,
|
|
302
353
|
hidden_states: jax.Array,
|
|
303
354
|
last_token_indices: jax.Array) -> tuple[jax.Array, jax.Array]:
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
355
|
+
positions = positions[last_token_indices]
|
|
356
|
+
residual = residual[last_token_indices]
|
|
357
|
+
draft_token_ids = self._select_draft_token_ids(state, hidden_states,
|
|
358
|
+
last_token_indices)
|
|
359
|
+
|
|
360
|
+
positions = lax.with_sharding_constraint(
|
|
361
|
+
positions, NamedSharding(self.mesh, PartitionSpec(None, )))
|
|
362
|
+
residual = lax.with_sharding_constraint(
|
|
363
|
+
residual, NamedSharding(self.mesh, PartitionSpec(None, None)))
|
|
364
|
+
draft_token_ids = lax.with_sharding_constraint(
|
|
365
|
+
draft_token_ids, NamedSharding(self.mesh, PartitionSpec()))
|
|
366
|
+
|
|
367
|
+
return positions, residual, draft_token_ids
|
|
307
368
|
|
|
308
369
|
def propose(
|
|
309
370
|
self,
|
|
@@ -330,11 +391,11 @@ class Eagle3Proposer:
|
|
|
330
391
|
|
|
331
392
|
if self.num_speculative_tokens == 1:
|
|
332
393
|
return kv_caches, self._select_draft_token_ids(
|
|
333
|
-
hidden_states, last_token_indices)
|
|
394
|
+
self.state, hidden_states, last_token_indices)
|
|
334
395
|
|
|
335
396
|
positions, hidden_states, draft_token_ids = self._select_inputs_for_loop_speculation(
|
|
336
|
-
attn_metadata.input_positions, residual[0],
|
|
337
|
-
last_token_indices)
|
|
397
|
+
self.state, attn_metadata.input_positions, residual[0],
|
|
398
|
+
hidden_states, last_token_indices)
|
|
338
399
|
|
|
339
400
|
draft_token_ids_list = [draft_token_ids]
|
|
340
401
|
|
|
@@ -359,7 +420,8 @@ class Eagle3Proposer:
|
|
|
359
420
|
attn_metadata,
|
|
360
421
|
)
|
|
361
422
|
hidden_states = residual[0]
|
|
362
|
-
draft_token_ids = self._get_draft_token_ids(
|
|
423
|
+
draft_token_ids = self._get_draft_token_ids(
|
|
424
|
+
self.state, new_hidden_states)
|
|
363
425
|
draft_token_ids_list.append(draft_token_ids)
|
|
364
426
|
|
|
365
427
|
# [batch_size, num_speculative_tokens]
|
tpu_inference/tpu_info.py
CHANGED
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
import glob
|
|
2
16
|
import os
|
|
3
17
|
|