tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +22 -1
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +167 -97
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +31 -9
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +280 -210
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +77 -36
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +91 -31
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -4
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -71
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +158 -63
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +53 -30
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +105 -57
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +65 -19
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +65 -52
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
- tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511270815.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
tpu_inference/runner/tpu_runner.py
CHANGED

@@ -1,6 +1,20 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import copy
 import functools
-import
+import logging
 import random
 from contextlib import nullcontext
 from dataclasses import dataclass
@@ -10,17 +24,15 @@ import jax
 import jax.numpy as jnp
 import jaxtyping
 import numpy as np
-import
-import vllm.envs as envs
+import vllm.envs as vllm_envs
 from flax import nnx
 from jax.experimental import mesh_utils
 from jax.sharding import NamedSharding, PartitionSpec
-from torchax.ops.mappings import j2t_dtype
 from vllm.config import VllmConfig
+from vllm.distributed import get_pp_group
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.forward_context import set_forward_context
-from vllm.sequence import IntermediateTensors
 from vllm.tasks import SupportedTask
 from vllm.utils.math_utils import cdiv
 from vllm.v1.core.sched.output import GrammarOutput
@@ -35,6 +47,7 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import \
     KVConnectorModelRunnerMixin
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

+import tpu_inference.envs as envs
 from tpu_inference import utils as common_utils
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.common.sharding import (MESH_AXIS_NAMES,
@@ -48,6 +61,8 @@ from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata
 from tpu_inference.logger import init_logger
 from tpu_inference.models.common.model_loader import get_model
+from tpu_inference.models.jax.jax_intermediate_tensor import \
+    JaxIntermediateTensors
 from tpu_inference.models.jax.utils.weight_utils import (
     shard_put, transfer_state_with_mappings)
 from tpu_inference.runner import utils as runner_utils
@@ -64,10 +79,12 @@ from tpu_inference.runner.structured_decoding_manager import \
     StructuredDecodingManager
 from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
 from tpu_inference.utils import (device_array, make_optimized_mesh,
-                                 time_function)
+                                 time_function, to_jax_dtype, to_torch_dtype)

 logger = init_logger(__name__)

+logging.getLogger("torchax.tensor").setLevel(logging.ERROR)
+
 INVALID_TOKEN_ID = -1
 # Smallest output size
 MIN_NUM_SEQS = 8
@@ -78,17 +95,6 @@ DUMMY_METADATA = AttentionMetadata(
     request_distribution=[0, 0, 0],
 )

-TPU_STR_DTYPE_TO_TORCH_DTYPE = {
-    "half": torch.half,
-    "bfloat16": torch.bfloat16,
-    "float": torch.float,
-    "fp8": torch.float8_e4m3fn,
-    "fp8_e4m3": torch.float8_e4m3fn,
-    "fp8_e5m2": torch.float8_e5m2,
-    "int8": torch.int8,
-    "uint8": torch.uint8,
-}
-

 class AsyncTPUModelRunnerOutput(AsyncModelRunnerOutput):
     """Holds asynchronous model output specifically from a TPU runner.
@@ -243,6 +249,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         self.maybe_forbid_compile = runner_utils.ForbidCompile(
         ) if envs.VLLM_XLA_CHECK_RECOMPILATION else nullcontext()
         self.dp_size = self.vllm_config.sharding_config.total_dp_size
+        self.rank = rank
+        self.is_first_rank = is_first_rank
+        self.is_last_rank = is_last_rank

         self._init_random()
         self._init_mesh()
@@ -253,31 +262,21 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

         # Delegate functions to specific manager classes.
         self.compilation_manager = CompilationManager(self)
-        self.
-
+        if self.is_last_rank:
+            self.speculative_decoding_manager = SpeculativeDecodingManager(
+                self)
+            self.structured_decoding_manager = StructuredDecodingManager(self)
         self.kv_cache_manager = KVCacheManager(self)
         self.mm_manager = MultiModalManager(self)
         self.persistent_batch_manager = PersistentBatchManager(
             self.requests, self.input_batch, self.encoder_cache,
-            self.uses_mrope, self.model_config)
+            self.uses_mrope, self.model_config, self.is_last_rank)
         self.lora_utils = LoraUtils(self)

-
-        if
-
-
-            self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
-        elif isinstance(getattr(model_dtype, 'dtype', None), jnp.dtype):
-            self.kv_cache_dtype = j2t_dtype(model_dtype.dtype)
-        elif isinstance(model_dtype, torch.dtype):
-            self.kv_cache_dtype = model_dtype
-        else:
-            raise ValueError(
-                "KV cache is unsupported for model_dtype of %s",
-                model_dtype)
-        else:
-            self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[
-                cache_config.cache_dtype]
+        cache_dtype = self.cache_config.cache_dtype
+        if cache_dtype == "auto":
+            cache_dtype = self.dtype
+        self.kv_cache_dtype = to_torch_dtype(cache_dtype)

         self._pre_async_results: AsyncPreResults | None = None
         self._substitute_placeholder_token_fn = _substitute_placeholder_token
@@ -291,7 +290,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         self.rng_key = jax.random.key(self.model_config.seed)

     def _init_mesh(self) -> None:
-        if
+        if envs.NEW_MODEL_DESIGN:
             self.mesh = self._create_new_model_mesh()
         else:
             # NOTE(wenxindongwork): The new MoE kernel expects a 2D mesh, so we need
@@ -302,7 +301,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         logger.info(f"Init mesh | mesh={self.mesh}")

     def _create_new_model_mesh(self) -> jax.sharding.Mesh:
-        num_slices =
+        num_slices = envs.NUM_SLICES

         logger.info(f"Creating new model mesh | devices={len(self.devices)}, "
                     f"num_slices={num_slices}")
@@ -371,7 +370,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                                            devices=self.devices)

     def _init_phased_profiling(self) -> None:
-        self.phased_profiling_dir =
+        self.phased_profiling_dir = envs.PHASED_PROFILING_DIR
         self.phase_based_profiler = None
         if self.phased_profiling_dir:
             self.phase_based_profiler = runner_utils.PhasedBasedProfiler(
@@ -413,7 +412,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             min_token_size=max(16, self.dp_size),
             max_token_size=scheduler_config.max_num_batched_tokens *
             self.dp_size,
-            padding_gap=
+            padding_gap=vllm_envs.VLLM_TPU_BUCKET_PADDING_GAP)
         self.num_tokens_paddings_per_dp = [
             padding // self.dp_size for padding in self.num_tokens_paddings
         ]
@@ -509,10 +508,10 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         multimodal_fns = multimodal_fns or {}
         self.precompile_vision_encoder_fn = multimodal_fns.get(
             "precompile_vision_encoder_fn", None)
-        self.
-
-        self.
-
+        self.embed_multimodal_fn = multimodal_fns.get("embed_multimodal_fn",
+                                                      None)
+        self.embed_input_ids_fn = multimodal_fns.get("embed_input_ids_fn",
+                                                     None)
         self.get_mrope_input_positions_fn = multimodal_fns.get(
             "get_mrope_input_positions_fn", None)

@@ -524,7 +523,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             jax.random.key(self.model_config.seed)).params()
         self.is_multimodal_model = (
             self.model_config.is_multimodal_model
-            and self.
+            and self.embed_multimodal_fn is not None and hasattr(
                 self.model_config.hf_config, "architectures"
             )  #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
             and len(self.model_config.hf_config.architectures) >= 1
@@ -540,7 +539,10 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
     def get_kv_cache_spec(self):
         return self.kv_cache_manager.get_kv_cache_spec()

-    def initialize_kv_cache(self,
+    def initialize_kv_cache(self,
+                            kv_cache_config: KVCacheConfig,
+                            topology_order_id: int = 0) -> None:
+        self.topology_order_id = topology_order_id
         self.kv_cache_config = kv_cache_config
         self.use_hybrid_kvcache = len(kv_cache_config.kv_cache_groups) > 1
         self.kv_caches = []
@@ -555,12 +557,12 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
     def execute_model(
         self,
         scheduler_output: "VllmSchedulerOutput",
-        intermediate_tensors: Optional[
-    ) -> ModelRunnerOutput | None:
+        intermediate_tensors: Optional[JaxIntermediateTensors] = None,
+    ) -> ModelRunnerOutput | JaxIntermediateTensors | None:
         if self.execute_model_state is not None:
             raise RuntimeError("State error: sample_tokens() must be called "
                                "after execute_model() returns None.")
-        _, output = self._execute_model(scheduler_output)
+        _, output = self._execute_model(scheduler_output, intermediate_tensors)
         return output

     def sample_tokens(
@@ -686,7 +688,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
     def _execute_model(
         self,
         scheduler_output: "VllmSchedulerOutput",
-
+        intermediate_tensors: Optional[JaxIntermediateTensors] = None,
+    ) -> tuple[AttentionMetadata, JaxIntermediateTensors | ModelRunnerOutput
+               | None]:
         self.persistent_batch_manager.update_states(
             scheduler_output, self.get_mrope_input_positions_fn)
         if not scheduler_output.total_num_scheduled_tokens:
@@ -764,7 +768,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 scheduler_output) as kv_connector_output:
             # NOTE(Wenlong): It takes both `input_ids` and `inputs_embeds`,
             # but one of them would be `None`
-
             (self.kv_caches, hidden_states,
              aux_hidden_states) = self.model_fn(
                  self.state,
@@ -775,8 +778,14 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 input_positions,
                 tuple(self.layer_name_to_kvcache_index.items()),
                 lora_metadata,
+                intermediate_tensors,
+                self.is_first_rank,
+                self.is_last_rank,
             )
-
+        if not get_pp_group().is_last_rank:
+            assert isinstance(hidden_states, JaxIntermediateTensors)
+            hidden_states.kv_connector_output = kv_connector_output
+            return attn_metadata, hidden_states
         hidden_states = self._select_from_array_fn(hidden_states,
                                                    logits_indices)
         logits = self.compute_logits_fn(
@@ -818,7 +827,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         sharding = None
         if self.dp_size > 1:
             sharding = NamedSharding(self.mesh,
-                                     PartitionSpec(ShardingAxisName.
+                                     PartitionSpec(ShardingAxisName.MLP_DATA))

         tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
             self.mesh, self.input_batch, padded_num_reqs, sharding=sharding)
@@ -1345,7 +1354,14 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         _request_distribution = []
         for dp_rank in range(dp_size):
             _num_reqs = num_req_per_dp_rank[dp_rank]
-
+            # The batch has been reordered by _reorder_batch so decode requests come first
+            # Count decode requests (those with num_scheduled_tokens == 1) in this DP rank
+            num_decode_in_dp_rank = 0
+            for req_id in req_ids_dp[dp_rank]:
+                if scheduler_output.num_scheduled_tokens[req_id] == 1:
+                    num_decode_in_dp_rank += 1
+            _request_distribution.append(
+                [num_decode_in_dp_rank, num_decode_in_dp_rank, _num_reqs])
         request_distribution = np.array(_request_distribution).ravel()

         use_spec_decode = len(
@@ -1374,7 +1390,8 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             self.mesh,
             self.input_batch,
             padded_num_reqs,
-            sharding=
+            sharding=NamedSharding(self.mesh,
+                                   PartitionSpec(ShardingAxisName.MLP_DATA)),
         )
         if self.uses_mrope:
             positions = mrope_positions
@@ -1404,7 +1421,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 block_tables[
                     req_offset:req_offset + _num_reqs, :self.
                     max_num_blocks_per_req] = self.input_batch.block_table[
-
+                        kv_cache_gid].get_cpu_tensor()[req_indices_dp[dp_rank]]
         # Convert block_tables to 1D on cpu.
         block_tables = block_tables.reshape(-1)
         block_tables = device_array(
@@ -1664,7 +1681,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
     def _get_input_ids_embeds(self, input_ids: jax.Array,
                               mm_embeds: list[jax.Array]):
         if self.is_multimodal_model:
-            inputs_embeds = self.
+            inputs_embeds = self.embed_input_ids_fn(
                 self.state,
                 input_ids,
                 mm_embeds,
@@ -1719,3 +1736,34 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             mappings=mappings,
             transpose_keys=transpose_keys,
             shard=shard)
+
+    def get_intermediate_tensor_spec(self, num_tokens: int):
+        jax_dtype = to_jax_dtype(self.dtype)
+        num_padded_tokens = runner_utils.get_padded_token_len(
+            self.num_tokens_paddings, num_tokens)
+        sharding = NamedSharding(self.mesh, PartitionSpec())
+        hidden_size = self.model_config.get_hidden_size()
+        spec = jax.ShapeDtypeStruct(shape=(num_padded_tokens, hidden_size),
+                                    dtype=jax_dtype,
+                                    sharding=sharding)
+        tensor_spec = {"hidden_states": spec, "residual": spec}
+        return tensor_spec
+
+    def get_uuid_for_jax_transfer(self,
+                                  scheduler_output: "VllmSchedulerOutput",
+                                  rank: int, step: int) -> int:
+        '''
+        Get a uuid for jax.transfer, here we use the hash of
+        scheduler_output + counter_step + sender's rank
+        '''
+        scheduler_output_str = ""
+        if not scheduler_output.num_scheduled_tokens:
+            scheduler_output_str = "empty_batch"
+        else:
+            scheduler_output_str = str(
+                sorted(scheduler_output.num_scheduled_tokens.items()))
+        unique_str = f'{scheduler_output_str} {step} {rank}'
+        import hashlib
+        hasher = hashlib.sha1()
+        hasher.update(unique_str.encode('utf-8'))
+        return int.from_bytes(hasher.digest()[:8], 'big')
tpu_inference/runner/utils.py
CHANGED
@@ -15,6 +15,7 @@ import jax
 from jax._src.interpreters import pxla
 from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput

+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 from tpu_inference.runner.input_batch import InputBatch

@@ -306,8 +307,7 @@ class PhasedBasedProfiler:
             InferencePhase.BALANCED: False
         }
         self.default_profiling_options = jax.profiler.ProfileOptions()
-        self.default_profiling_options.python_tracer_level =
-            "PYTHON_TRACER_LEVEL", 0)
+        self.default_profiling_options.python_tracer_level = envs.PYTHON_TRACER_LEVEL

         self.current_phase: str = ""

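This change, like the NEW_MODEL_DESIGN, NUM_SLICES, and PHASED_PROFILING_DIR reads in tpu_runner.py above, replaces inline environment lookups with attributes of tpu_inference.envs (+92 -8 in this release). A sketch of the presumed pattern, assuming the module eagerly parses typed values with defaults; the actual envs.py may differ:

import os

def _int_env(name: str, default: int) -> int:
    # Read an integer-valued environment variable, falling back to a default.
    return int(os.environ.get(name, default))

# Consumed above as envs.PYTHON_TRACER_LEVEL by PhasedBasedProfiler; the
# default of 0 matches the removed inline lookup ("PYTHON_TRACER_LEVEL", 0).
PYTHON_TRACER_LEVEL: int = _int_env("PYTHON_TRACER_LEVEL", 0)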
tpu_inference/spec_decode/__init__.py
ADDED

@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tpu_inference/spec_decode/jax/__init__.py
ADDED

@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tpu_inference/spec_decode/jax/eagle3.py
CHANGED

@@ -1,3 +1,16 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """Implements the Eagle3 proposer for speculative decoding on JAX/TPU."""
 import functools
 from dataclasses import replace
@@ -6,6 +19,9 @@ from typing import Any, Optional
 import jax
 import jax.numpy as jnp
 import numpy as np
+from flax import nnx
+from jax import lax
+from jax.sharding import NamedSharding, PartitionSpec
 from vllm.config import VllmConfig

 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
@@ -127,6 +143,17 @@ class Eagle3Proposer:
                                       max_num_blocks_per_req)
         new_block_tables = jnp.where(expanded_exceeds_mask, -1, block_tables)

+        positions = lax.with_sharding_constraint(
+            positions, NamedSharding(self.mesh, PartitionSpec(None, )))
+        clamped_positions = lax.with_sharding_constraint(
+            clamped_positions, NamedSharding(self.mesh, PartitionSpec(None, )))
+        new_seq_lens = lax.with_sharding_constraint(
+            new_seq_lens, NamedSharding(self.mesh, PartitionSpec(None, )))
+        query_start_loc = lax.with_sharding_constraint(
+            query_start_loc, NamedSharding(self.mesh, PartitionSpec()))
+        new_block_tables = lax.with_sharding_constraint(
+            new_block_tables, NamedSharding(self.mesh, PartitionSpec(None, )))
+
         return positions, clamped_positions, new_seq_lens, query_start_loc, new_block_tables

     @functools.partial(jax.jit, static_argnums=(0, ))
@@ -138,6 +165,7 @@ class Eagle3Proposer:
     @functools.partial(jax.jit, static_argnums=(0, ))
     def _prepare_hidden_states_and_input_ids(
         self,
+        state: nnx.State,
         aux_hidden_states: tuple[jax.Array, ...],
         query_start_loc: jax.Array,
         target_token_ids: jax.Array,
@@ -146,7 +174,7 @@ class Eagle3Proposer:
     ) -> tuple[jax.Array, jax.Array, jax.Array]:
         target_hidden_states = jnp.concatenate(aux_hidden_states, axis=-1)
         target_hidden_states = self.combine_hidden_states_fn(
-
+            state, target_hidden_states)

         input_ids, last_token_indices = self._prepare_input_ids(
             query_start_loc, target_token_ids, next_token_ids, num_reqs)
@@ -193,8 +221,8 @@ class Eagle3Proposer:
             block_tables=device_array(
                 self.mesh, block_tables))
         target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
-            aux_hidden_states, attn_metadata.query_start_loc,
-            next_token_ids, num_reqs)
+            self.state, aux_hidden_states, attn_metadata.query_start_loc,
+            input_ids, next_token_ids, num_reqs)
         return target_hidden_states, input_ids, last_token_indices, attn_metadata

     # Host copies from the metadata prepared by the runner.
@@ -258,12 +286,13 @@ class Eagle3Proposer:

         attn_metadata = replace(attn_metadata, block_tables=block_tables)
         return self._filter_token_and_prepare_initial_inputs(
-            token_indices, query_start_loc, seq_lens, input_ids,
+            self.state, token_indices, query_start_loc, seq_lens, input_ids,
             aux_hidden_states, attn_metadata, next_token_ids, num_reqs)

     @functools.partial(jax.jit, static_argnums=(0, ))
     def _filter_token_and_prepare_initial_inputs(
         self,
+        state: nnx.State,
         token_indices: jax.Array,
         query_start_loc: jax.Array,
         seq_lens: jax.Array,
@@ -291,35 +320,51 @@ class Eagle3Proposer:
         )

         target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
-            [h[token_indices] for h in aux_hidden_states],
-            target_token_ids, next_token_ids, num_reqs)
+            state, [h[token_indices] for h in aux_hidden_states],
+            query_start_loc, target_token_ids, next_token_ids, num_reqs)

         return target_hidden_states, input_ids, last_token_indices, attn_metadata

     @functools.partial(jax.jit, static_argnums=(0, ))
     def _select_draft_token_ids(
         self,
+        state: nnx.State,
         hidden_states: jax.Array,
         last_token_indices: jax.Array,
     ) -> jax.Array:
         sample_hidden_states = hidden_states[last_token_indices]
-
+        sample_hidden_states = lax.with_sharding_constraint(
+            sample_hidden_states,
+            NamedSharding(self.mesh, PartitionSpec(None, None)))
+        return self._get_draft_token_ids(state, sample_hidden_states)

     @functools.partial(jax.jit, static_argnums=(0, ))
-    def _get_draft_token_ids(self,
+    def _get_draft_token_ids(self, state: nnx.State,
+                             hidden_states: jax.Array) -> jax.Array:
         lora_metadata = None
-        logits = self.compute_logits_fn(
-
-        return
+        logits = self.compute_logits_fn(state, hidden_states, lora_metadata)
+        draft_token_ids = jnp.argmax(logits, axis=-1)
+        return lax.with_sharding_constraint(
+            draft_token_ids, NamedSharding(self.mesh, PartitionSpec()))

     @functools.partial(jax.jit, static_argnums=(0, ))
     def _select_inputs_for_loop_speculation(
-        self, positions: jax.Array, residual: jax.Array,
+        self, state: nnx.State, positions: jax.Array, residual: jax.Array,
         hidden_states: jax.Array,
         last_token_indices: jax.Array) -> tuple[jax.Array, jax.Array]:
-
-
-
+        positions = positions[last_token_indices]
+        residual = residual[last_token_indices]
+        draft_token_ids = self._select_draft_token_ids(state, hidden_states,
+                                                       last_token_indices)
+
+        positions = lax.with_sharding_constraint(
+            positions, NamedSharding(self.mesh, PartitionSpec(None, )))
+        residual = lax.with_sharding_constraint(
+            residual, NamedSharding(self.mesh, PartitionSpec(None, None)))
+        draft_token_ids = lax.with_sharding_constraint(
+            draft_token_ids, NamedSharding(self.mesh, PartitionSpec()))
+
+        return positions, residual, draft_token_ids

     def propose(
         self,
@@ -346,11 +391,11 @@ class Eagle3Proposer:

         if self.num_speculative_tokens == 1:
             return kv_caches, self._select_draft_token_ids(
-                hidden_states, last_token_indices)
+                self.state, hidden_states, last_token_indices)

         positions, hidden_states, draft_token_ids = self._select_inputs_for_loop_speculation(
-            attn_metadata.input_positions, residual[0],
-            last_token_indices)
+            self.state, attn_metadata.input_positions, residual[0],
+            hidden_states, last_token_indices)

         draft_token_ids_list = [draft_token_ids]

@@ -375,7 +420,8 @@ class Eagle3Proposer:
                 attn_metadata,
             )
             hidden_states = residual[0]
-            draft_token_ids = self._get_draft_token_ids(
+            draft_token_ids = self._get_draft_token_ids(
+                self.state, new_hidden_states)
             draft_token_ids_list.append(draft_token_ids)

             # [batch_size, num_speculative_tokens]
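The recurring edit in eagle3.py above wraps jit intermediates in lax.with_sharding_constraint so the compiler keeps them on an explicit layout between the jitted speculation steps rather than inferring one. A minimal runnable sketch of the pattern on a single-device mesh (shapes and names are illustrative, not taken from the package):

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-device mesh for illustration; the proposer uses the runner's TPU mesh.
mesh = Mesh(np.array(jax.devices()[:1]), ("data",))

@jax.jit
def select_last(hidden_states: jax.Array,
                last_token_indices: jax.Array) -> jax.Array:
    sampled = hidden_states[last_token_indices]
    # Pin the gathered rows to an explicit (here fully replicated) sharding,
    # mirroring _select_draft_token_ids above.
    return lax.with_sharding_constraint(
        sampled, NamedSharding(mesh, PartitionSpec(None, None)))

print(select_last(jnp.ones((16, 8)), jnp.array([3, 7, 15])).shape)  # (3, 8)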
tpu_inference/tpu_info.py
CHANGED
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import glob
 import os
