tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +14 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +25 -8
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +14 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +20 -3
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +20 -26
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +22 -3
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +100 -455
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
- tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +37 -16
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +113 -124
- tpu_inference/models/jax/gpt_oss.py +23 -7
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
- tpu_inference/models/jax/utils/weight_utils.py +32 -1
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +27 -29
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +69 -35
- tpu_inference/runner/kv_cache.py +14 -0
- tpu_inference/runner/kv_cache_manager.py +15 -2
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +30 -10
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +31 -30
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +23 -7
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -208
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tpu_inference/runner/kv_cache.py
CHANGED
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
from typing import Any, List
|
|
2
16
|
|
|
3
17
|
import jax
|
|
@@ -1,5 +1,19 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
import functools
|
|
2
|
-
from typing import TYPE_CHECKING,
|
|
16
|
+
from typing import TYPE_CHECKING, List
|
|
3
17
|
|
|
4
18
|
import jax
|
|
5
19
|
import jax.numpy as jnp
|
|
@@ -198,7 +212,6 @@ class KVCacheManager:
|
|
|
198
212
|
# uniform page size.
|
|
199
213
|
representative_spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec
|
|
200
214
|
page_size_bytes = representative_spec.page_size_bytes
|
|
201
|
-
self.runner.layer_name_to_kvcache_index: Dict[str, int] = {}
|
|
202
215
|
kv_caches = self.runner.kv_caches
|
|
203
216
|
num_blocks_list = []
|
|
204
217
|
for i, kv_cache_tensor in enumerate(kv_cache_config.kv_cache_tensors):
|
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
from __future__ import annotations
|
|
2
16
|
|
|
3
17
|
from typing import TYPE_CHECKING
|
|
@@ -7,7 +21,8 @@ from torchax.interop import jax_view
|
|
|
7
21
|
from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA
|
|
8
22
|
from vllm.lora.request import LoRARequest
|
|
9
23
|
|
|
10
|
-
from tpu_inference.layers.vllm.
|
|
24
|
+
from tpu_inference.layers.vllm.process_weights.cleanup_sharding import \
|
|
25
|
+
update_lora
|
|
11
26
|
|
|
12
27
|
if TYPE_CHECKING:
|
|
13
28
|
from tpu_inference.runner.tpu_runner import TPUModelRunner
|
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
from typing import TYPE_CHECKING
|
|
2
16
|
|
|
3
17
|
import jax
|
|
@@ -98,7 +112,7 @@ class MultiModalManager:
|
|
|
98
112
|
# encoder outputs.
|
|
99
113
|
encoder_outputs = []
|
|
100
114
|
for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
|
|
101
|
-
mm_kwargs
|
|
115
|
+
mm_kwargs):
|
|
102
116
|
batched_mm_inputs = mm_kwargs_group
|
|
103
117
|
# Convert torch tensors to numpy arrays that JAX can handle.
|
|
104
118
|
if "pixel_values" in batched_mm_inputs and isinstance(
|
|
@@ -134,7 +148,7 @@ class MultiModalManager:
|
|
|
134
148
|
# 2. A list or tuple (length: num_items) of tensors, each of shape
|
|
135
149
|
# (feature_size, hidden_size) in case the feature size is dynamic
|
|
136
150
|
# depending on the input multimodal items.
|
|
137
|
-
curr_group_outputs = self.runner.
|
|
151
|
+
curr_group_outputs = self.runner.embed_multimodal_fn(
|
|
138
152
|
self.runner.state, image_grid_thw, **batched_mm_inputs)
|
|
139
153
|
|
|
140
154
|
sanity_check_mm_encoder_outputs(
|
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
from typing import Dict
|
|
2
16
|
|
|
3
17
|
import jax
|
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
from __future__ import annotations
|
|
2
16
|
|
|
3
17
|
from dataclasses import dataclass
|
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
import functools
|
|
2
16
|
from typing import TYPE_CHECKING, Tuple
|
|
3
17
|
|
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
import copy
|
|
2
16
|
import functools
|
|
3
17
|
import logging
|
|
@@ -268,6 +282,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
268
282
|
self._substitute_placeholder_token_fn = _substitute_placeholder_token
|
|
269
283
|
self.execute_model_state: ExecuteModelState | None = None
|
|
270
284
|
|
|
285
|
+
self.kv_caches: list[jax.Array] = []
|
|
286
|
+
self.layer_name_to_kvcache_index: dict[str, int] = {}
|
|
287
|
+
|
|
271
288
|
def _init_random(self):
|
|
272
289
|
if self.model_config.seed is None:
|
|
273
290
|
self.model_config.seed = 0
|
|
@@ -494,10 +511,10 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
494
511
|
multimodal_fns = multimodal_fns or {}
|
|
495
512
|
self.precompile_vision_encoder_fn = multimodal_fns.get(
|
|
496
513
|
"precompile_vision_encoder_fn", None)
|
|
497
|
-
self.
|
|
498
|
-
|
|
499
|
-
self.
|
|
500
|
-
|
|
514
|
+
self.embed_multimodal_fn = multimodal_fns.get("embed_multimodal_fn",
|
|
515
|
+
None)
|
|
516
|
+
self.embed_input_ids_fn = multimodal_fns.get("embed_input_ids_fn",
|
|
517
|
+
None)
|
|
501
518
|
self.get_mrope_input_positions_fn = multimodal_fns.get(
|
|
502
519
|
"get_mrope_input_positions_fn", None)
|
|
503
520
|
|
|
@@ -509,7 +526,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
509
526
|
jax.random.key(self.model_config.seed)).params()
|
|
510
527
|
self.is_multimodal_model = (
|
|
511
528
|
self.model_config.is_multimodal_model
|
|
512
|
-
and self.
|
|
529
|
+
and self.embed_multimodal_fn is not None and hasattr(
|
|
513
530
|
self.model_config.hf_config, "architectures"
|
|
514
531
|
) #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
|
|
515
532
|
and len(self.model_config.hf_config.architectures) >= 1
|
|
@@ -525,10 +542,12 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
525
542
|
def get_kv_cache_spec(self):
|
|
526
543
|
return self.kv_cache_manager.get_kv_cache_spec()
|
|
527
544
|
|
|
528
|
-
def initialize_kv_cache(self,
|
|
545
|
+
def initialize_kv_cache(self,
|
|
546
|
+
kv_cache_config: KVCacheConfig,
|
|
547
|
+
topology_order_id: int = 0) -> None:
|
|
548
|
+
self.topology_order_id = topology_order_id
|
|
529
549
|
self.kv_cache_config = kv_cache_config
|
|
530
550
|
self.use_hybrid_kvcache = len(kv_cache_config.kv_cache_groups) > 1
|
|
531
|
-
self.kv_caches = []
|
|
532
551
|
self.kv_cache_manager.initialize_kv_cache(kv_cache_config)
|
|
533
552
|
if has_kv_transfer_group():
|
|
534
553
|
get_kv_transfer_group().register_runner(self)
|
|
@@ -810,7 +829,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
810
829
|
sharding = None
|
|
811
830
|
if self.dp_size > 1:
|
|
812
831
|
sharding = NamedSharding(self.mesh,
|
|
813
|
-
PartitionSpec(ShardingAxisName.
|
|
832
|
+
PartitionSpec(ShardingAxisName.MLP_DATA))
|
|
814
833
|
|
|
815
834
|
tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
|
|
816
835
|
self.mesh, self.input_batch, padded_num_reqs, sharding=sharding)
|
|
@@ -1373,7 +1392,8 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
1373
1392
|
self.mesh,
|
|
1374
1393
|
self.input_batch,
|
|
1375
1394
|
padded_num_reqs,
|
|
1376
|
-
sharding=
|
|
1395
|
+
sharding=NamedSharding(self.mesh,
|
|
1396
|
+
PartitionSpec(ShardingAxisName.MLP_DATA)),
|
|
1377
1397
|
)
|
|
1378
1398
|
if self.uses_mrope:
|
|
1379
1399
|
positions = mrope_positions
|
|
@@ -1663,7 +1683,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
1663
1683
|
def _get_input_ids_embeds(self, input_ids: jax.Array,
|
|
1664
1684
|
mm_embeds: list[jax.Array]):
|
|
1665
1685
|
if self.is_multimodal_model:
|
|
1666
|
-
inputs_embeds = self.
|
|
1686
|
+
inputs_embeds = self.embed_input_ids_fn(
|
|
1667
1687
|
self.state,
|
|
1668
1688
|
input_ids,
|
|
1669
1689
|
mm_embeds,
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -1,3 +1,16 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
1
14
|
"""Implements the Eagle3 proposer for speculative decoding on JAX/TPU."""
|
|
2
15
|
import functools
|
|
3
16
|
from dataclasses import replace
|
tpu_inference/tpu_info.py
CHANGED
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
import glob
|
|
2
16
|
import os
|
|
3
17
|
|
tpu_inference/utils.py
CHANGED
|
@@ -3,7 +3,7 @@ import time
|
|
|
3
3
|
from collections import defaultdict
|
|
4
4
|
from collections.abc import Sequence
|
|
5
5
|
from functools import wraps
|
|
6
|
-
from typing import Any, Callable, List, Tuple
|
|
6
|
+
from typing import Any, Callable, List, Tuple, Union
|
|
7
7
|
|
|
8
8
|
import jax
|
|
9
9
|
import jax.numpy as jnp
|
|
@@ -283,35 +283,6 @@ def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
|
|
|
283
283
|
return utils.hashing.get_hash_fn_by_name(hash_fn_name)
|
|
284
284
|
|
|
285
285
|
|
|
286
|
-
def quantize_kv(key: jax.Array, value: jax.Array,
|
|
287
|
-
kv_cache_quantized_dtype: jnp.dtype, k_scale: float,
|
|
288
|
-
v_scale: float) -> Tuple[jax.Array, jax.Array]:
|
|
289
|
-
"""
|
|
290
|
-
Quantize the key and value tensors.
|
|
291
|
-
|
|
292
|
-
Args:
|
|
293
|
-
key: The key tensor to quantize.
|
|
294
|
-
value: The value tensor to quantize.
|
|
295
|
-
kv_cache_quantized_dtype: The dtype to quantize the key and value tensors to.
|
|
296
|
-
q_scale: The scale to quantize the key and value tensors by.
|
|
297
|
-
k_scale: The scale to quantize the key tensor by.
|
|
298
|
-
v_scale: The scale to quantize the value tensor by.
|
|
299
|
-
|
|
300
|
-
Returns:
|
|
301
|
-
Tuple[jax.Array, jax.Array]: The quantized key and value tensors.
|
|
302
|
-
"""
|
|
303
|
-
dtype_info = jnp.finfo(kv_cache_quantized_dtype)
|
|
304
|
-
minval, maxval = float(dtype_info.min), float(dtype_info.max)
|
|
305
|
-
key = key.astype(jnp.float32) / k_scale
|
|
306
|
-
key = jnp.clip(key, minval, maxval)
|
|
307
|
-
key = key.astype(kv_cache_quantized_dtype)
|
|
308
|
-
value = value.astype(jnp.float32) / v_scale
|
|
309
|
-
value = jnp.clip(value, minval, maxval)
|
|
310
|
-
value = value.astype(kv_cache_quantized_dtype)
|
|
311
|
-
|
|
312
|
-
return key, value
|
|
313
|
-
|
|
314
|
-
|
|
315
286
|
def get_jax_dtype_from_str_dtype(str_dtype: str) -> jnp.dtype:
|
|
316
287
|
"""
|
|
317
288
|
Get the JAX dtype from a string dtype.
|
|
@@ -326,6 +297,36 @@ def get_jax_dtype_from_str_dtype(str_dtype: str) -> jnp.dtype:
|
|
|
326
297
|
return to_jax_dtype(str_dtype)
|
|
327
298
|
|
|
328
299
|
|
|
300
|
+
def get_mesh_shape_product(
|
|
301
|
+
mesh: Mesh,
|
|
302
|
+
axes: Union[str, list[str], None],
|
|
303
|
+
) -> int:
|
|
304
|
+
"""
|
|
305
|
+
Get the product of mesh dimensions for one or more axes.
|
|
306
|
+
|
|
307
|
+
Examples:
|
|
308
|
+
# Single axis (defaults to 1 if not present)
|
|
309
|
+
get_mesh_shape_product(mesh, "model")
|
|
310
|
+
|
|
311
|
+
# Multiple axes - computes product of their sizes
|
|
312
|
+
get_mesh_shape_product(mesh, ["model", "attn_dp"])
|
|
313
|
+
|
|
314
|
+
# None means no sharding on this dimension
|
|
315
|
+
get_mesh_shape_product(mesh, None) # returns 1
|
|
316
|
+
"""
|
|
317
|
+
if axes is None:
|
|
318
|
+
return 1
|
|
319
|
+
|
|
320
|
+
if isinstance(axes, str):
|
|
321
|
+
axes = [axes]
|
|
322
|
+
|
|
323
|
+
product = 1
|
|
324
|
+
for axis in axes:
|
|
325
|
+
product *= mesh.shape.get(axis, 1)
|
|
326
|
+
|
|
327
|
+
return product
|
|
328
|
+
|
|
329
|
+
|
|
329
330
|
def time_function(func):
|
|
330
331
|
"""
|
|
331
332
|
A decorator to measure the execution time of a function.
|
tpu_inference/worker/__init__.py
CHANGED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -26,8 +26,8 @@ from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
|
|
|
26
26
|
|
|
27
27
|
from tpu_inference import envs, utils
|
|
28
28
|
from tpu_inference.distributed import jax_parallel_state
|
|
29
|
-
from tpu_inference.distributed.utils import (
|
|
30
|
-
|
|
29
|
+
from tpu_inference.distributed.utils import (get_device_topology_order_id,
|
|
30
|
+
get_host_ip, get_kv_transfer_port)
|
|
31
31
|
from tpu_inference.layers.common.sharding import ShardingConfigManager
|
|
32
32
|
from tpu_inference.logger import init_logger
|
|
33
33
|
from tpu_inference.models.jax.jax_intermediate_tensor import \
|
|
@@ -232,9 +232,16 @@ class TPUWorker:
|
|
|
232
232
|
|
|
233
233
|
is_first_rank = True
|
|
234
234
|
is_last_rank = True
|
|
235
|
+
self.topology_order_id = self.rank
|
|
235
236
|
if self.parallel_config.pipeline_parallel_size > 1:
|
|
236
237
|
is_first_rank = self.rank == 0
|
|
237
238
|
is_last_rank = self.rank == self.pp_config.pp_world_size - 1
|
|
239
|
+
else:
|
|
240
|
+
# topology_order_id is used to determine the KV cache
|
|
241
|
+
# mapping between P/D workers
|
|
242
|
+
if multihost_backend == "ray":
|
|
243
|
+
self.topology_order_id = get_device_topology_order_id(
|
|
244
|
+
jax.local_devices(), jax.devices())
|
|
238
245
|
|
|
239
246
|
self.model_runner = TPUModelRunner(self.vllm_config, self.devices,
|
|
240
247
|
self.rank, is_first_rank,
|
|
@@ -243,9 +250,12 @@ class TPUWorker:
|
|
|
243
250
|
f"rank={self.rank} | "
|
|
244
251
|
f"is_first_rank={is_first_rank} | "
|
|
245
252
|
f"is_last_rank={is_last_rank} | "
|
|
246
|
-
f"
|
|
253
|
+
f"topology_order_id={self.topology_order_id} | "
|
|
247
254
|
f"is_driver_worker={self.is_driver_worker} | "
|
|
248
|
-
f"hbm={utils.hbm_usage_gb(self.devices)}GiB"
|
|
255
|
+
f"hbm={utils.hbm_usage_gb(self.devices)}GiB |"
|
|
256
|
+
f"self.devices={self.devices} | "
|
|
257
|
+
f"total devices={jax.devices()} | "
|
|
258
|
+
f"local_devices={jax.local_devices()}")
|
|
249
259
|
vllm_utils.report_usage_stats(self.vllm_config)
|
|
250
260
|
|
|
251
261
|
def initialize_pp_transfer_connect(self):
|
|
@@ -420,13 +430,19 @@ class TPUWorker:
|
|
|
420
430
|
kv_cache_config: KVCacheConfig,
|
|
421
431
|
) -> None:
|
|
422
432
|
"""Allocate GPU KV cache with the specified kv_cache_config."""
|
|
423
|
-
|
|
433
|
+
# Precompile functions with large vocab_size tensors before allocating KV cache to avoid OOM
|
|
434
|
+
if not (envs.SKIP_JAX_PRECOMPILE or
|
|
435
|
+
(hasattr(self.model_runner.model_config, "enforce_eager")
|
|
436
|
+
and self.model_runner.model_config.enforce_eager)):
|
|
437
|
+
self.model_runner.compilation_manager._precompile_sampling()
|
|
438
|
+
self.model_runner.compilation_manager._precompile_gather_logprobs()
|
|
439
|
+
self.model_runner.initialize_kv_cache(kv_cache_config,
|
|
440
|
+
self.topology_order_id)
|
|
424
441
|
|
|
425
442
|
def get_node_kv_ip_port(self) -> tuple[int, str, int]:
|
|
426
|
-
node_id = get_node_id()
|
|
427
443
|
ip = get_host_ip()
|
|
428
444
|
port = get_kv_transfer_port()
|
|
429
|
-
return (int(
|
|
445
|
+
return (int(self.topology_order_id), ip, int(port))
|
|
430
446
|
|
|
431
447
|
def check_health(self) -> None:
|
|
432
448
|
# worker will always be healthy as long as it's running.
|