tpu-inference 0.12.0.dev20251222__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +67 -0
- tests/core/test_dp_scheduler.py +724 -0
- tests/core/test_init.py +63 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +393 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +291 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +388 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +498 -0
- tests/kernels/quantized_matmul_kernel_test.py +159 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/layers/jax/test_qwix.py +969 -0
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +297 -0
- tests/layers/vllm/test_unquantized.py +621 -0
- tests/layers/vllm/utils.py +72 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +46 -0
- tests/lora/test_bgmv.py +57 -0
- tests/lora/test_layers.py +666 -0
- tests/lora/test_lora.py +147 -0
- tests/lora/test_lora_perf.py +67 -0
- tests/lora/utils.py +88 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +606 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +202 -0
- tests/runner/test_tpu_runner_dp.py +1033 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +215 -0
- tests/test_envs.py +280 -0
- tests/test_tpu_info.py +134 -0
- tests/test_utils.py +193 -0
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +67 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +49 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +814 -0
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +81 -0
- tpu_inference/distributed/tpu_connector.py +732 -0
- tpu_inference/distributed/utils.py +112 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +191 -0
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +399 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +272 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +1340 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +403 -0
- tpu_inference/layers/common/attention_metadata.py +48 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +23 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +600 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +268 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
- tpu_inference/layers/jax/base.py +165 -0
- tpu_inference/layers/jax/constants.py +101 -0
- tpu_inference/layers/jax/layers.py +315 -0
- tpu_inference/layers/jax/misc.py +30 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
- tpu_inference/layers/jax/moe/moe.py +249 -0
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +294 -0
- tpu_inference/layers/jax/rope_interface.py +228 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
- tpu_inference/layers/jax/sample/sampling.py +110 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
- tpu_inference/layers/jax/transformer_block.py +121 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +502 -0
- tpu_inference/layers/vllm/linear_common.py +221 -0
- tpu_inference/layers/vllm/quantization/__init__.py +55 -0
- tpu_inference/layers/vllm/quantization/awq.py +221 -0
- tpu_inference/layers/vllm/quantization/common.py +124 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
- tpu_inference/layers/vllm/sharding.py +244 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +98 -0
- tpu_inference/lora/torch_punica_tpu.py +310 -0
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +520 -0
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +978 -0
- tpu_inference/models/jax/gpt_oss.py +508 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
- tpu_inference/models/jax/llama3.py +436 -0
- tpu_inference/models/jax/llama4.py +643 -0
- tpu_inference/models/jax/llama_eagle3.py +350 -0
- tpu_inference/models/jax/llama_guard_4.py +375 -0
- tpu_inference/models/jax/qwen2.py +390 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
- tpu_inference/models/jax/qwen3.py +318 -0
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +110 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
- tpu_inference/models/jax/utils/weight_utils.py +621 -0
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
- tpu_inference/platforms/__init__.py +16 -0
- tpu_inference/platforms/tpu_platform.py +258 -0
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +890 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +166 -0
- tpu_inference/runner/kv_cache_manager.py +508 -0
- tpu_inference/runner/lora_utils.py +106 -0
- tpu_inference/runner/multimodal_manager.py +231 -0
- tpu_inference/runner/persistent_batch_manager.py +296 -0
- tpu_inference/runner/speculative_decoding_manager.py +262 -0
- tpu_inference/runner/structured_decoding_manager.py +101 -0
- tpu_inference/runner/tpu_runner.py +1768 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +430 -0
- tpu_inference/tpu_info.py +92 -0
- tpu_inference/utils.py +345 -0
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +468 -0
- tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
- tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
- tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
- tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import functools
|
|
16
|
+
from dataclasses import dataclass
|
|
17
|
+
from typing import Optional
|
|
18
|
+
|
|
19
|
+
import jax
|
|
20
|
+
import jax.numpy as jnp
|
|
21
|
+
import torch
|
|
22
|
+
from jax.sharding import Mesh
|
|
23
|
+
|
|
24
|
+
from tpu_inference.runner.input_batch import InputBatch
|
|
25
|
+
from tpu_inference.utils import device_array
|
|
26
|
+
|
|
27
|
+
# Default sampling values. `TPUSupportedSamplingMetadata.from_input_batch`
# writes these into the padding rows of the per-request CPU buffers.
DEFAULT_SAMPLING_PARAMS = {
    "temperature": -1.0,
    "top_k": 0,
    "top_p": 1.0,
}
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@functools.partial(
    jax.tree_util.register_dataclass,
    data_fields=["temperature", "top_k", "top_p"],
    meta_fields=["do_sampling", "logprobs"],
)
@dataclass
class TPUSupportedSamplingMetadata:
    """Sampling parameters for one batch, packaged for the TPU sampler.

    Registered as a JAX pytree: ``temperature``, ``top_k`` and ``top_p`` are
    data fields (device arrays, or ``None`` for an all-greedy batch), while
    ``do_sampling`` and ``logprobs`` are static metadata fields.
    """

    temperature: Optional[jnp.ndarray] = None
    top_k: Optional[jnp.ndarray] = None
    top_p: Optional[jnp.ndarray] = None
    do_sampling: bool = False
    logprobs: bool = False

    @classmethod
    def from_input_batch(
        cls,
        mesh: Mesh,
        input_batch: InputBatch,
        padded_num_reqs: int,
        sharding: Optional[jax.sharding.Sharding] = None,
    ) -> "TPUSupportedSamplingMetadata":
        """Build sampling metadata from the persistent ``InputBatch`` buffers.

        Args:
            mesh: Device mesh forwarded to ``device_array``.
            input_batch: Source of the per-request CPU sampling buffers
                (``temperature_cpu``, ``top_k_cpu``, ``top_p_cpu``).
            padded_num_reqs: Number of rows sent to the device. Rows in
                ``[input_batch.num_reqs, padded_num_reqs)`` of each CPU buffer
                are overwritten in place with ``DEFAULT_SAMPLING_PARAMS``.
            sharding: Optional sharding forwarded to ``device_array``.

        Returns:
            A metadata instance. For an all-greedy batch only ``logprobs`` is
            set and the parameter arrays are left as ``None``.
        """
        max_logprobs = input_batch.max_num_logprobs
        # Falsy (None/0) means no logprobs; otherwise require a positive count.
        wants_logprobs = bool(max_logprobs) and max_logprobs > 0

        if input_batch.all_greedy:
            # Greedy decoding needs no temperature/top-k/top-p arrays.
            return cls(do_sampling=False, logprobs=wants_logprobs)

        num_live = input_batch.num_reqs
        cpu_buffers = {
            "temperature": input_batch.temperature_cpu,
            "top_k": input_batch.top_k_cpu,
            "top_p": input_batch.top_p_cpu,
        }
        # Reset the padding rows of each CPU buffer (in place) to the default.
        for param, buf in cpu_buffers.items():
            buf[num_live:padded_num_reqs] = DEFAULT_SAMPLING_PARAMS[param]

        def _to_device(param: str):
            # Send a fixed padded-length prefix so the array shape matches
            # the pre-compiled padded size.
            return device_array(mesh,
                                cpu_buffers[param][:padded_num_reqs],
                                sharding=sharding)

        return cls(
            temperature=_to_device("temperature"),
            top_p=_to_device("top_p"),
            top_k=_to_device("top_k"),
            do_sampling=not input_batch.all_greedy,
            logprobs=wants_logprobs,
        )
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from dataclasses import dataclass
|
|
16
|
+
from typing import Any, Optional, Tuple
|
|
17
|
+
|
|
18
|
+
# Flax and JAX sharding imports
|
|
19
|
+
import jax
|
|
20
|
+
from flax import nnx
|
|
21
|
+
|
|
22
|
+
from tpu_inference.layers.jax.attention.attention import (AttentionMetadata,
|
|
23
|
+
KVCache)
|
|
24
|
+
from tpu_inference.layers.jax.layers import DenseFFW
|
|
25
|
+
from tpu_inference.layers.jax.moe.moe import MoE
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass(kw_only=True)
class TransformerBlock(nnx.Module):
    """A stateful block: attention sub-layer, then feed-forward sub-layer.

    Each sub-layer normalizes its input first (``pre_attention_norm`` /
    ``pre_mlp_norm``) and adds that input back as a residual afterwards.
    ``custom_module`` is the feed-forward sub-layer and can be either a dense
    module (e.g. ``DenseFFW``) or an ``MoE``.
    """
    pre_attention_norm: nnx.Module
    pre_mlp_norm: nnx.Module
    custom_module: Optional[nnx.Module] = None
    attn: nnx.Module
    use_attention_rope: bool = True
    # Not referenced by __call__.
    quant: Any | None = None

    def __call__(
        self, x_TD: jax.Array, is_prefill: bool, kv_cache: KVCache,
        attention_metadata: AttentionMetadata
    ) -> Tuple[KVCache, jax.Array]:
        """Apply attention and then the feed-forward module to ``x_TD``.

        Returns:
            ``(new_kv_cache, output_TD)``, where the cache is the one returned
            by ``self.attn``.
        """
        # Attention sub-layer: normalize, attend, add the block input back.
        normed_TD = self.pre_attention_norm(x_TD)
        updated_cache, hidden_TD = self.attn(normed_TD, is_prefill, kv_cache,
                                             attention_metadata,
                                             self.use_attention_rope)
        hidden_TD = hidden_TD + x_TD

        # Feed-forward sub-layer: normalize, transform, add residual.
        ffw_out_TD = self.custom_module(self.pre_mlp_norm(hidden_TD))
        return updated_cache, ffw_out_TD + hidden_TD
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
@dataclass(kw_only=True)
class SharedExpertsTransformerBlock(TransformerBlock):
    """A TransformerBlock that sums MoE layer output with shared-expert output.

    Users can provide the FFW layer in two ways:
      1. Pass the module (either ``MoE`` or ``DenseFFW``) as the
         ``custom_module`` attribute.
      2. Set the ``moe_ffw`` or ``dense_ffw`` attribute instead
         (e.g., for passing quantized modules).

    An ``MoE`` in ``custom_module`` takes precedence over ``moe_ffw``, and a
    ``DenseFFW`` in ``custom_module`` takes precedence over ``dense_ffw``. When
    an MoE layer is available it is used in preference to any dense layer, and
    its output is summed with the ``shared_experts`` module's output.

    Attributes:
        moe_ffw: Optional MoE layer.
        dense_ffw: Optional DenseFFW layer.
        shared_experts: Shared experts module. Required when an MoE layer is
            used; ignored on the dense path.
    """

    moe_ffw: Optional[MoE] = None
    dense_ffw: Optional[DenseFFW] = None
    shared_experts: Optional[DenseFFW] = None

    def __call__(
        self, x_TD: jax.Array, is_prefill: bool, kv_cache: KVCache,
        attention_metadata: AttentionMetadata
    ) -> Tuple[KVCache, jax.Array]:
        """Apply attention, then the MoE (+ shared experts) or dense FFW.

        Returns:
            ``(new_kv_cache, output_TD)``, where the cache is the one returned
            by ``self.attn``.

        Raises:
            ValueError: if no MoE/dense layer is configured, or if an MoE layer
                is used while ``shared_experts`` is unset.
        """
        # Attn Block
        attn_residual_TD = x_TD
        x_TD = self.pre_attention_norm(x_TD)
        new_cache, attn_output_TD = self.attn(x_TD, is_prefill, kv_cache,
                                              attention_metadata,
                                              self.use_attention_rope)
        attn_output_TD += attn_residual_TD

        # FFW Block
        ffw_residual_TD = attn_output_TD
        normed_ffw_input_TD = self.pre_mlp_norm(attn_output_TD)

        # A module passed via custom_module wins over the dedicated attribute.
        moe_layer = (self.custom_module if isinstance(self.custom_module, MoE)
                     else self.moe_ffw)
        dense_layer = (self.custom_module
                       if isinstance(self.custom_module, DenseFFW) else
                       self.dense_ffw)

        if moe_layer is not None:
            if self.shared_experts is None:
                # Previously this surfaced as an opaque "'NoneType' object is
                # not callable" TypeError.
                raise ValueError(
                    "shared_experts must be set when an MoE layer is used in "
                    "SharedExpertsTransformerBlock!")
            ffw_output_TD = moe_layer(normed_ffw_input_TD)
            # Add the shared expert outputs to the MoE outputs.
            ffw_output_TD += self.shared_experts(normed_ffw_input_TD)
        elif dense_layer is not None:
            ffw_output_TD = dense_layer(normed_ffw_input_TD)
        else:
            raise ValueError(
                "Neither custom_module, moe_ffw nor dense_ffw attribute is set for this SharedExpertsTransformerBlock!"
            )

        ffw_output_TD += ffw_residual_TD
        return new_cache, ffw_output_TD
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,221 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
3
|
+
import functools
|
|
4
|
+
from typing import Optional, Tuple
|
|
5
|
+
|
|
6
|
+
import jax
|
|
7
|
+
import jax.numpy as jnp
|
|
8
|
+
import torch
|
|
9
|
+
from jax.sharding import Mesh
|
|
10
|
+
from torchax.interop import jax_view, torch_view
|
|
11
|
+
from torchax.ops.mappings import t2j
|
|
12
|
+
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
|
13
|
+
AttentionLayer, AttentionType)
|
|
14
|
+
|
|
15
|
+
from tpu_inference import utils
|
|
16
|
+
from tpu_inference.layers.common.attention_interface import attention
|
|
17
|
+
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
|
|
18
|
+
from tpu_inference.layers.common.quantization import quantize_kv
|
|
19
|
+
from tpu_inference.logger import init_logger
|
|
20
|
+
from tpu_inference.models.vllm.vllm_model_wrapper_context import \
|
|
21
|
+
get_vllm_model_wrapper_context
|
|
22
|
+
|
|
23
|
+
# Module-level logger created with the project's `init_logger` helper.
logger = init_logger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class PallasAttentionBackend(AttentionBackend):
    """vLLM attention backend registered under the name ``"PALLAS"``.

    Its implementation class is ``PallasAttentionBackendImpl``.
    """

    @staticmethod
    def get_name() -> str:
        """Return the backend's name string."""
        backend_name = "PALLAS"
        return backend_name

    @staticmethod
    def get_impl_cls() -> type["PallasAttentionBackendImpl"]:
        """Return the ``AttentionImpl`` subclass used by this backend."""
        impl_cls = PallasAttentionBackendImpl
        return impl_cls
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class PallasAttentionBackendImpl(AttentionImpl):
    """vLLM ``AttentionImpl`` that computes attention with JAX.

    ``forward`` receives torch tensors from vLLM and views them as JAX arrays,
    then calls ``_jax_attn_func``. The KV cache is read from and written back
    to the vLLM model-wrapper context; the ``kv_cache`` tensor vLLM passes in
    must be empty.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        sinks: torch.Tensor | None = None,
    ) -> None:
        """Store the layer configuration and reject unsupported features.

        Raises:
            NotImplementedError: if ``alibi_slopes`` is given or ``attn_type``
                is not ``AttentionType.DECODER``.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        # NOTE(review): stored but not forwarded to the attention call in
        # `forward`; soft-capping is currently not applied.
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        if alibi_slopes is not None:
            raise NotImplementedError("Alibi slopes is not supported.")
        # None means the KV cache is not quantized (kv_cache_dtype == "auto").
        self.kv_cache_quantized_dtype = None
        if kv_cache_dtype != "auto":
            self.kv_cache_quantized_dtype = utils.get_jax_dtype_from_str_dtype(
                kv_cache_dtype)

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "PallasAttentionBackendImpl")

        self.sinks = sinks
        if self.sinks is not None:
            assert self.sinks.shape[0] == num_heads, (
                "Sinks must have the same number of heads as the number of "
                "heads in the layer")

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Convert ``sinks`` (if any) to a float32 torchax-backed Parameter."""
        #TODO (kyuyeunk): Shard the sinks along num_heads dim
        if self.sinks is not None:
            sinks = t2j(self.sinks, use_dlpack=False)
            sinks = torch_view(sinks.astype(jnp.float32))
            self.sinks = torch.nn.Parameter(sinks, requires_grad=False)

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute attention for ``layer`` and update its cached KV entries.

        ``output`` is not used; a new tensor is returned instead.

        Raises:
            NotImplementedError: if ``output_scale`` is given.
            RuntimeError: if vLLM's own ``kv_cache`` tensor is non-empty.
        """
        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported for "
                "PallasAttentionBackendImpl")

        if kv_cache.numel():
            # Exceptions do not %-format their arguments; build the message.
            raise RuntimeError(
                "KV cache from vLLM Attention layer should be empty but has "
                f"the size of {kv_cache.numel()}.")

        del kv_cache  # Use kv_cache from vllm wrapper context values instead.

        vllm_model_wrapper_context = get_vllm_model_wrapper_context()
        kv_cache_index = vllm_model_wrapper_context.layer_name_to_kvcache_index[
            layer.layer_name]
        kv_cache = vllm_model_wrapper_context.kv_caches[kv_cache_index]

        mesh = vllm_model_wrapper_context.mesh

        query, key, value = jax_view(query), jax_view(key), jax_view(value)
        q_scale = k_scale = v_scale = None
        # Explicit None check: a dtype object can be falsy (e.g. numpy dtypes
        # report len() == 0), which would silently skip quantization.
        if self.kv_cache_quantized_dtype is not None:
            key, value = quantize_kv(self.kv_cache_quantized_dtype, key, value,
                                     layer._k_scale_float,
                                     layer._v_scale_float)
            # TODO(kyuyeunk): Enable w8a8 when VREG spill issue is resolved.
            # q_scale = layer._q_scale_float
            k_scale = layer._k_scale_float
            v_scale = layer._v_scale_float

        sinks = jax_view(self.sinks)

        new_kv_cache, outputs = _jax_attn_func(
            kv_cache,
            query,
            key,
            value,
            sinks,
            attn_metadata,
            mesh,
            self.scale,
            self.head_size,
            self.num_heads,
            self.num_kv_heads,
            q_scale,
            k_scale,
            v_scale,
            self.sliding_window,
        )
        # The old cache buffer was donated to _jax_attn_func; store the new one.
        vllm_model_wrapper_context.kv_caches[kv_cache_index] = new_kv_cache

        return torch_view(outputs)
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
@functools.partial(
    jax.jit,
    static_argnames=(
        "mesh",
        "scale",
        "head_size",
        "num_heads",
        "num_kv_heads",
        "q_scale",
        "k_scale",
        "v_scale",
        "sliding_window",
    ),
    # A real one-element tuple; ("kv_cache") without the comma is just a str.
    donate_argnames=("kv_cache", ),
)
def _jax_attn_func(
    kv_cache: jax.Array,
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    sinks: jax.Array | None,
    attention_metadata: AttentionMetadata,
    mesh: Mesh,
    scale: float,
    head_size: int,
    num_heads: int,
    num_kv_heads: int,
    q_scale: float | None = None,
    k_scale: float | None = None,
    v_scale: float | None = None,
    sliding_window: int | None = None,
) -> Tuple[jax.Array, jax.Array]:
    """Run ``attention`` on vLLM-layout q/k/v and return the updated cache.

    Args:
        kv_cache: This layer's KV cache. It is donated, so callers must use
            the returned cache afterwards.
        q: Query of shape ``(q_len, num_heads * head_size)``.
        k: Key of shape ``(k_len, num_kv_heads * head_size)``.
        v: Value, same shape as ``k``.
        sinks: Optional attention sinks forwarded to ``attention``.
        attention_metadata: Forwarded to ``attention``.
        mesh: Device mesh forwarded to ``attention``.
        scale: Accepted but currently unused (see below).
        head_size, num_heads, num_kv_heads: Used to unflatten q/k/v.
        q_scale, k_scale, v_scale: Optional quantization scales.
        sliding_window: Forwarded as ``attention_chunk_size``.

    Returns:
        ``(new_kv_cache, outputs)`` with ``outputs`` of shape
        ``(q_len, num_heads * head_size)``.
    """
    # Unused for now, as the attention function applies a default scale.
    # NOTE(review): confirm that default matches the model's `scale`.
    del scale

    # Get shapes from vLLM: q/k/v arrive flattened to (tokens, heads * dim).
    q_len, q_compute_dim = q.shape
    k_len, k_compute_dim = k.shape
    assert k.shape == v.shape
    assert q_compute_dim == head_size * num_heads
    assert k_compute_dim == head_size * num_kv_heads

    # Convert the shapes from vLLM's convention to what the attention
    # function expects:
    # (q_len, num_heads, head_size)
    q = q.reshape(q_len, num_heads, head_size)
    # (k_len, num_kv_heads, head_size)
    k = k.reshape(k_len, num_kv_heads, head_size)
    v = v.reshape(k_len, num_kv_heads, head_size)

    # NOTE(review): `sliding_window` is passed as `attention_chunk_size`;
    # confirm `attention` treats chunked attention equivalently here.
    new_kv_cache, outputs = attention(
        kv_cache,
        q,
        k,
        v,
        attention_metadata,
        mesh,
        q_scale=q_scale,
        k_scale=k_scale,
        v_scale=v_scale,
        sinks=sinks,
        attention_chunk_size=sliding_window,
    )

    # Convert the shape back to vLLM's convention: (q_len, num_heads * head_size).
    assert outputs.shape[0] == q_len
    assert outputs.shape[1] == num_heads
    assert outputs.shape[2] == head_size
    outputs = outputs.reshape(q_len, q_compute_dim)

    return new_kv_cache, outputs
|