tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.2rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +78 -1
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +38 -7
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +17 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +28 -5
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +74 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +89 -26
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -64
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +72 -37
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +46 -17
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +44 -17
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +42 -36
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +63 -50
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/METADATA +7 -9
- tpu_inference-0.13.2rc3.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/top_level.txt +0 -0
tpu_inference/runner/compilation_manager.py CHANGED
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import time
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
 
@@ -32,6 +46,8 @@ class CompilationManager:
 
     def __init__(self, runner: "TPUModelRunner"):
        self.runner = runner
+        self._sampling_precompiled = False
+        self._gather_logprobs_precompiled = False
        if not vllm_envs.VLLM_DISABLE_COMPILE_CACHE:
            logger.info("Enabling JAX compile cache.")
            jax.config.update("jax_compilation_cache_dir",
@@ -86,9 +102,13 @@ class CompilationManager:
             return
         self._precompile_select_from_array()
         self._precompile_compute_logits()
+        # Skip sampling if already precompiled before KV cache allocation
+        if not self._sampling_precompiled:
+            self._precompile_sampling()
         self._precompile_disagg_utils()
-
-        self.
+        # Skip gather_logprobs if already precompiled before KV cache allocation
+        if not self._gather_logprobs_precompiled:
+            self._precompile_gather_logprobs()
         self._precompile_structured_decoding()
         if self.runner.speculative_config:
             self._precompile_speculative_decoding()
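Note: the new `_sampling_precompiled` / `_gather_logprobs_precompiled` flags make precompilation idempotent, so those two steps can run once before KV cache allocation and be skipped on the later full pass. A minimal sketch of the run-once guard pattern (class and names are illustrative, not from the package):

    class PrecompileGuard:
        """Run each expensive warm-up step at most once per instance."""

        def __init__(self) -> None:
            self._done: set[str] = set()

        def run_once(self, name: str, fn, *args, **kwargs) -> None:
            # Skip steps that already ran (e.g. before KV cache allocation).
            if name in self._done:
                return
            fn(*args, **kwargs)
            self._done.add(name)

    guard = PrecompileGuard()
    guard.run_once("sampling", print, "compiling sampling")   # runs
    guard.run_once("sampling", print, "compiling sampling")   # skipped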
@@ -107,7 +127,7 @@
 
         self._run_compilation(
             "input_embeddings_merger",
-            self.runner.
+            self.runner.embed_input_ids_fn,
             self.runner.state,
             dummy_input_ids,
             dummy_multimodal_embeddings,
@@ -116,7 +136,7 @@
 
         self._run_compilation(
             "input_embeddings_merger_text_only",
-            self.runner.
+            self.runner.embed_input_ids_fn,
             self.runner.state,
             dummy_input_ids,
             None,
@@ -466,43 +486,48 @@ class CompilationManager:
         for num_reqs in self.runner.num_reqs_paddings:
             logits_sharding = NamedSharding(
                 self.runner.mesh,
-                PartitionSpec(ShardingAxisName.
+                PartitionSpec(ShardingAxisName.MLP_DATA,
+                              ShardingAxisName.MLP_TENSOR))
             dp_size = self.runner.vllm_config.sharding_config.total_dp_size
             sampling_metadata_sharding = NamedSharding(
                 self.runner.mesh, PartitionSpec(
-                    ShardingAxisName.
+                    ShardingAxisName.MLP_DATA)) if dp_size > 1 else None
             logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16,
                                                logits_sharding)
             for do_sampling in (True, False):
-                (29 deleted lines not rendered in the source view)
+                for logprobs in (True, False):
+                    if do_sampling:
+                        temperature = np.full((num_reqs, ),
+                                              0.7,
+                                              dtype=np.float32)
+                        top_k = np.full((num_reqs, ), 20, dtype=np.int32)
+                        top_p = np.full((num_reqs, ), 0.8, dtype=np.float32)
+                        (temperature, top_k, top_p) = device_array(
+                            self.runner.mesh, (temperature, top_k, top_p),
+                            sharding=sampling_metadata_sharding)
+                    else:
+                        temperature = None
+                        top_k = None
+                        top_p = None
+
+                    sampling_metadata = TPUSupportedSamplingMetadata(
+                        temperature=temperature,
+                        top_k=top_k,
+                        top_p=top_p,
+                        do_sampling=do_sampling,
+                        logprobs=logprobs)
+                    self._run_compilation(
+                        f"worker{self.runner.rank} sample",
+                        sample,
+                        self.runner.rng_params_for_sampling,
+                        self.runner.mesh,
+                        logits,
+                        sampling_metadata,
+                        num_reqs=num_reqs,
+                        do_sampling=do_sampling,
+                    )
+
+        self._sampling_precompiled = True
 
     def _precompile_disagg_utils(self) -> None:
         if not is_disagg_enabled():
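Note: the rewritten sampling precompile walks the full grid of padded request counts x `do_sampling` x `logprobs`, so every shape/flag combination the server can hit is traced ahead of time. A hedged JAX sketch of that warm-up idea (toy stand-in for the real `sample` kernel):

    import jax
    import jax.numpy as jnp

    @jax.jit
    def sample_stub(logits, temperature):
        # Toy stand-in: scale logits per request, then take the argmax.
        return jnp.argmax(logits / temperature[:, None], axis=-1)

    VOCAB = 128
    for num_reqs in (8, 16, 32):  # padded request-count buckets
        logits = jnp.zeros((num_reqs, VOCAB), jnp.bfloat16)
        temperature = jnp.full((num_reqs,), 0.7, jnp.float32)
        # The first call per padded shape compiles; identical shapes at
        # serve time then reuse the cached executable instead of retracing.
        sample_stub(logits, temperature).block_until_ready()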
@@ -532,8 +557,16 @@ class CompilationManager:
         logger.info("Compiling gather_logprobs with different input shapes.")
         hsize = self.runner.model_config.get_vocab_size()
         for num_reqs in self.runner.num_reqs_paddings:
-            (2 deleted lines not rendered in the source view)
+            logits_sharding = NamedSharding(
+                self.runner.mesh,
+                PartitionSpec(ShardingAxisName.MLP_DATA,
+                              ShardingAxisName.MLP_TENSOR))
+            token_ids_sharding = NamedSharding(
+                self.runner.mesh, PartitionSpec(ShardingAxisName.MLP_DATA, ))
+            logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16,
+                                               logits_sharding)
+            token_ids = self._create_dummy_tensor((num_reqs, ), jnp.int32,
+                                                  token_ids_sharding)
             self._run_compilation(
                 f"worker{self.runner.rank} gather_logprobs",
                 self.runner._compute_and_gather_logprobs,
@@ -543,6 +576,8 @@ class CompilationManager:
                 num_reqs=num_reqs,
             )
 
+        self._gather_logprobs_precompiled = True
+
     def _precompile_speculative_decoding(self) -> None:
         logger.info(
             "Compiling speculative_decoding with different input shapes.")
tpu_inference/runner/kv_cache.py CHANGED
@@ -1,3 +1,17 @@
+# (14-line Apache 2.0 license header added, identical to the one above)
 from typing import Any, List
 
 import jax
@@ -7,6 +21,7 @@ from jax._src import dtypes
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
 from torchax.ops.mappings import t2j_dtype
 
+import tpu_inference.kernels.mla.v1.kernel as mla
 import tpu_inference.kernels.ragged_paged_attention.v3.kernel as rpa
 import tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 as rpa_hd64
 from tpu_inference.layers.common.sharding import ShardingAxisName
@@ -17,9 +32,13 @@ logger = init_logger(__name__)
 DEFAULT_KV_CACHE_DTYPE = jnp.bfloat16
 
 
-def get_kv_cache_shape_with_mesh(mesh: Mesh,
-    (2 deleted lines not rendered in the source view)
+def get_kv_cache_shape_with_mesh(mesh: Mesh,
+                                 total_num_pages: int,
+                                 page_size: int,
+                                 actual_num_kv_heads: int,
+                                 actual_head_dim: int,
+                                 kv_dtype: any,
+                                 use_mla: bool = False):
     """Gets the KV cache shape based on the mesh configuration."""
 
     model_cnt = mesh.shape["model"]
@@ -28,15 +47,21 @@ def get_kv_cache_shape_with_mesh(mesh: Mesh, total_num_pages: int,
     # specific model, rather than being determined by the head_dim. If new
     # models are introduced with a head_dim of 64, this will require additional
     # model-specific adjustments.
-    (9 deleted lines not rendered in the source view)
+    if use_mla:
+        get_kv_cache_shape_fn = mla.get_kv_cache_shape
+        shape = list(
+            get_kv_cache_shape_fn(total_num_pages, page_size, actual_head_dim,
+                                  kv_dtype))
+    else:
+        get_kv_cache_shape_fn = (
+            rpa_hd64.get_kv_cache_shape if actual_head_dim == 64 \
+            else rpa.get_kv_cache_shape
+        )
+        shape = list(
+            get_kv_cache_shape_fn(total_num_pages, page_size,
+                                  actual_num_kv_heads // model_cnt,
+                                  actual_head_dim, kv_dtype))
+        shape[2] *= model_cnt
     return tuple(shape)
 
 
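Note: `get_kv_cache_shape_with_mesh` now dispatches to the MLA kernel's shape helper for latent caches and keeps the RPA v3 helpers (plus the hd64 variant) otherwise; in the RPA path the KV heads are first divided across the mesh's `model` axis and the affected axis is scaled back up so the global shape stays intact. A worked example with an assumed RPA-like layout (illustrative only; the real helpers live in the kernel modules):

    def rpa_like_shape(total_pages, page_size, num_kv_heads, head_dim):
        # Assumed axis order for illustration; K/V pairs double axis 2.
        return (total_pages, page_size, num_kv_heads * 2, head_dim)

    model_cnt = 4                        # size of the mesh "model" axis
    shape = list(rpa_like_shape(1024, 16, 8 // model_cnt, 128))
    shape[2] *= model_cnt                # restore the global head axis
    print(tuple(shape))                  # (1024, 16, 16, 128)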
@@ -48,6 +73,7 @@ def create_kv_caches(
     mesh: Mesh,
     layer_names: List[str],
     cache_dtype: jnp.dtype = DEFAULT_KV_CACHE_DTYPE,
+    use_mla: bool = False,
 ) -> List[jax.Array]:
     """
     Creates a list of KV cache where each array mapps to single attention layer.
@@ -74,12 +100,16 @@ def create_kv_caches(
 
     cache_shape = get_kv_cache_shape_with_mesh(mesh, num_blocks, block_size,
                                                num_kv_heads, head_size,
-                                               cache_dtype)
+                                               cache_dtype, use_mla)
 
-    (deleted line not rendered in the source view)
-        mesh,
-    (2 deleted lines not rendered in the source view)
+    if use_mla:
+        sharding = NamedSharding(mesh,
+                                 PartitionSpec(ShardingAxisName.MLP_TENSOR))
+    else:
+        sharding = NamedSharding(
+            mesh,
+            PartitionSpec(ShardingAxisName.ATTN_DATA, None,
+                          ShardingAxisName.ATTN_HEAD))
 
     def _allocate() -> jax.Array:
         return jnp.empty(
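Note: the MLA branch shards the latent cache along `MLP_TENSOR` only, since a latent cache has no per-head axis to split; the regular branch keeps pages on `ATTN_DATA` and heads on `ATTN_HEAD`. A runnable single-device sketch of the `NamedSharding`/`PartitionSpec` API in play (axis names shortened for the example):

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    # A 1x1 mesh just to exercise the API; real meshes span many TPU chips.
    mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), ("data", "model"))

    mla_sharding = NamedSharding(mesh, PartitionSpec("model"))
    rpa_sharding = NamedSharding(mesh, PartitionSpec("data", None, "model"))

    # Pages on "data", heads on "model", page-internal axis replicated.
    cache = jax.device_put(np.zeros((8, 16, 4), np.float32), rpa_sharding)
    print(cache.sharding)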
@@ -94,7 +124,8 @@ def create_kv_caches(
     return kv_caches
 
 
-def get_rpa_page_size_bytes(mesh: Mesh, kv_cache_specs: dict[str, Any]) -> int:
+def get_attention_page_size_bytes(mesh: Mesh,
+                                  kv_cache_specs: dict[str, Any]) -> int:
     """
     Calculate KV cache page size of RPA kernel.
 
@@ -107,14 +138,16 @@ def get_rpa_page_size_bytes(mesh: Mesh, kv_cache_specs: dict[str, Any]) -> int:
     """
 
     # Import it here to avoid circular import.
-    from vllm.v1.kv_cache_interface import AttentionSpec
+    from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
 
     page_size_bytes_set = set()
     for kv_cache_spec in kv_cache_specs.values():
         assert isinstance(kv_cache_spec, AttentionSpec)
 
         dtype = t2j_dtype(kv_cache_spec.dtype)
-        bits = dtypes.bit_width(dtype)
+        bits = (dtypes.bit_width(dtype) if hasattr(dtypes, "bit_width") else
+                dtypes.itemsize_bits(dtype))
+        use_mla = isinstance(kv_cache_spec, MLAAttentionSpec)
 
         kv_cache_shape = get_kv_cache_shape_with_mesh(
             mesh=mesh,
@@ -123,6 +156,7 @@ def get_rpa_page_size_bytes(mesh: Mesh, kv_cache_specs: dict[str, Any]) -> int:
             actual_num_kv_heads=kv_cache_spec.num_kv_heads,
             actual_head_dim=kv_cache_spec.head_size,
             kv_dtype=dtype,
+            use_mla=use_mla,
         )
         page_size_bytes = (bits * np.prod(kv_cache_shape)) // 8
         page_size_bytes_set.add(page_size_bytes)
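Note: once the shape is known, a page's footprint is just `bits * prod(shape) // 8`. A worked example with illustrative numbers (bf16, one page of 16 tokens, 8 KV head pairs, head_dim 128):

    import numpy as np

    kv_cache_shape = (1, 16, 8 * 2, 128)  # (pages, page_size, K/V heads, head_dim)
    bits = 16                             # bfloat16
    page_size_bytes = (bits * np.prod(kv_cache_shape)) // 8
    print(page_size_bytes)                # 65536 -> 64 KiB per page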
tpu_inference/runner/kv_cache_manager.py CHANGED
@@ -1,5 +1,19 @@
+# (14-line Apache 2.0 license header added, identical to the one above)
 import functools
-from typing import TYPE_CHECKING,
+from typing import TYPE_CHECKING, List
 
 import jax
 import jax.numpy as jnp
@@ -39,20 +53,30 @@ class KVCacheManager:
         # means this layer will perform attention using the keys and values
         # from the KV cache of `shared_kv_cache_layers[layer_name]`.
         self.shared_kv_cache_layers: dict[str, str] = {}
+        self.use_mla = self.runner.model_config.use_mla
 
     def get_kv_cache_spec(self):
         # TODO(xiang): this hack tricks engine core to init successfully
         block_size = self.runner.cache_config.block_size
-        use_mla = self.runner.model_config.use_mla
         kv_cache_spec: dict[str, KVCacheSpec] = {}
 
         # If use pure jax (MODEL_IMPL_TYPE=flax_nnx), we don't register
         # attention into compilation config.
         # Use FullAttentionSpec for each layer
         # TODO(pooyam): Is it possible to merge the logic for vllm and non-vllm models?
+        model_config = self.runner.model_config
+        if self.use_mla:
+            # Individually pad the RopE and latents
+            qk_rope_head_dim = getattr(model_config.hf_text_config,
+                                       "qk_rope_head_dim", 0)
+            padded_kv_lora_rank = common_utils.align_to(
+                model_config.hf_text_config.kv_lora_rank, 128)
+            padded_qk_rope_head_dim = common_utils.align_to(
+                qk_rope_head_dim, 128)
+            mla_head_size = padded_kv_lora_rank + padded_qk_rope_head_dim
+
         if len(self.runner.vllm_config.compilation_config.
                static_forward_context) == 0:
-            model_config = self.runner.model_config
             parallel_config = self.runner.parallel_config
             # Pad num_kv_heads to multiple of TP size.
             num_kv_heads = common_utils.get_padded_num_heads(
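Note: the latent and RoPE widths are padded to multiples of 128 separately before being summed into `mla_head_size`. A worked example with DeepSeek-V3-style numbers (`align_to` re-implemented here under the assumption that it rounds up; the real one lives in `common_utils`):

    def align_to(x: int, multiple: int) -> int:
        # Assumed behavior of common_utils.align_to: round up to a multiple.
        return ((x + multiple - 1) // multiple) * multiple

    kv_lora_rank = 512      # illustrative latent width
    qk_rope_head_dim = 64   # illustrative RoPE sub-head width
    mla_head_size = align_to(kv_lora_rank, 128) + align_to(qk_rope_head_dim, 128)
    print(mla_head_size)    # 512 + 128 = 640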
@@ -61,11 +85,11 @@
             head_size = common_utils.get_padded_head_dim(
                 model_config.get_head_size())
             for i in range(model_config.get_num_layers(parallel_config)):
-                if use_mla:
+                if self.use_mla:
                     kv_cache_spec[f"layer.{i}"] = MLAAttentionSpec(
                         block_size=block_size,
-                        num_kv_heads=
-                        head_size=
+                        num_kv_heads=1,
+                        head_size=mla_head_size,
                         dtype=self.runner.kv_cache_dtype,
                         cache_dtype_str=self.runner.vllm_config.cache_config.
                         cache_dtype)
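Note: with `num_kv_heads=1` the MLA spec budgets one latent vector per token instead of K/V pairs for every head, which is where MLA's cache saving comes from. A back-of-the-envelope comparison with illustrative numbers:

    # Per-token cache width in elements (illustrative config).
    num_kv_heads, head_dim = 128, 128
    standard_kv = 2 * num_kv_heads * head_dim     # K and V for every head
    mla_latent = 1 * (512 + 64)                   # kv_lora_rank + qk_rope_head_dim
    print(standard_kv, mla_latent)                # 32768 vs 576 (~57x smaller)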
@@ -83,14 +107,13 @@
                     self.runner.mesh.shape["model"])
             head_size = common_utils.get_padded_head_dim(
                 hf_config.hidden_size // hf_config.num_attention_heads)
-
             # Eagle3 has only 1 layer
             for i in range(1):
-                if use_mla:
-                    kv_cache_spec[f"
+                if self.use_mla:
+                    kv_cache_spec[f"draft_layer.{i}"] = MLAAttentionSpec(
                         block_size=block_size,
-                        num_kv_heads=
-                        head_size=
+                        num_kv_heads=1,
+                        head_size=mla_head_size,
                         dtype=self.runner.kv_cache_dtype,
                         cache_dtype_str=self.runner.vllm_config.
                         cache_config.cache_dtype)
@@ -104,6 +127,7 @@
         # Else propagate attention modules from compilation config.
         layers = get_layers_from_vllm_config(self.runner.vllm_config,
                                              Attention)
+        logger.warning(f"Compilation num_layers = {len(layers.items())}")
         for layer_name, attn_module in layers.items():
             if (kv_tgt_layer :=
                     attn_module.kv_sharing_target_layer_name) is not None:
@@ -127,11 +151,11 @@
                         attn_module.head_size),
                     dtype=self.runner.kv_cache_dtype,
                     sliding_window=attn_module.sliding_window)
-            elif use_mla:
-                kv_cache_spec[
+            elif self.use_mla:
+                kv_cache_spec[layer_name] = MLAAttentionSpec(
                     block_size=block_size,
-                    num_kv_heads=
-                    head_size=
+                    num_kv_heads=1,
+                    head_size=mla_head_size,
                     dtype=self.runner.kv_cache_dtype,
                     cache_dtype_str=self.runner.vllm_config.
                     cache_config.cache_dtype)
@@ -188,7 +212,6 @@
         # uniform page size.
         representative_spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec
         page_size_bytes = representative_spec.page_size_bytes
-        self.runner.layer_name_to_kvcache_index: Dict[str, int] = {}
         kv_caches = self.runner.kv_caches
         num_blocks_list = []
         for i, kv_cache_tensor in enumerate(kv_cache_config.kv_cache_tensors):
@@ -198,14 +221,20 @@
             # num_blocks must be a multiple of dp_size
             num_blocks = (num_blocks // dp_size) * dp_size
             # NOTE: we'll multiply the num_kv_heads by 2 in the function
+            if self.use_mla:
+                head_size = self.runner.model_config.hf_config.kv_lora_rank + \
+                    self.runner.model_config.hf_config.qk_rope_head_dim
+            else:
+                head_size = representative_spec.head_size
             kv_cache = create_kv_caches(
                 num_blocks=num_blocks,
                 block_size=representative_spec.block_size,
                 num_kv_heads=representative_spec.num_kv_heads,
-                head_size=
+                head_size=head_size,
                 mesh=self.runner.mesh,
                 layer_names=[f'kv_cache_tensor.{i}'],
                 cache_dtype=t2j_dtype(representative_spec.dtype),
+                use_mla=self.use_mla,
             )[0]
             kv_caches.append(kv_cache)
             num_blocks_list.append(num_blocks)
tpu_inference/runner/lora_utils.py CHANGED
@@ -1,3 +1,17 @@
+# (14-line Apache 2.0 license header added, identical to the one above)
 from __future__ import annotations
 
 from typing import TYPE_CHECKING
tpu_inference/runner/multimodal_manager.py CHANGED
@@ -1,3 +1,17 @@
+# (14-line Apache 2.0 license header added, identical to the one above)
 from typing import TYPE_CHECKING
 
 import jax
@@ -134,7 +148,7 @@ class MultiModalManager:
         # 2. A list or tuple (length: num_items) of tensors, each of shape
         # (feature_size, hidden_size) in case the feature size is dynamic
         # depending on the input multimodal items.
-        curr_group_outputs = self.runner.
+        curr_group_outputs = self.runner.embed_multimodal_fn(
             self.runner.state, image_grid_thw, **batched_mm_inputs)
 
         sanity_check_mm_encoder_outputs(
tpu_inference/runner/persistent_batch_manager.py CHANGED
@@ -1,3 +1,17 @@
+# (14-line Apache 2.0 license header added, identical to the one above)
 from typing import Dict
 
 import jax
tpu_inference/runner/speculative_decoding_manager.py CHANGED
@@ -1,3 +1,17 @@
+# (14-line Apache 2.0 license header added, identical to the one above)
 from __future__ import annotations
 
 from dataclasses import dataclass
tpu_inference/runner/structured_decoding_manager.py CHANGED
@@ -1,3 +1,17 @@
+# (14-line Apache 2.0 license header added, identical to the one above)
 import functools
 from typing import TYPE_CHECKING, Tuple
 