tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +14 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +25 -8
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +14 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +20 -3
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +20 -26
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +22 -3
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +100 -455
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
- tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +37 -16
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +113 -124
- tpu_inference/models/jax/gpt_oss.py +23 -7
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
- tpu_inference/models/jax/utils/weight_utils.py +32 -1
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +27 -29
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +69 -35
- tpu_inference/runner/kv_cache.py +14 -0
- tpu_inference/runner/kv_cache_manager.py +15 -2
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +30 -10
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +31 -30
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +23 -7
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -208
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import math
 from dataclasses import InitVar, dataclass
 from typing import Any, Tuple
@@ -6,7 +20,6 @@ import jax
 import jax.numpy as jnp
 from flax import nnx
 from flax.typing import Sharding
-from jax.experimental import shard_map
 from jax.sharding import Mesh
 from jax.sharding import PartitionSpec as P
 
@@ -17,6 +30,7 @@ from tpu_inference.kernels.ragged_paged_attention.v3.kernel import \
 from tpu_inference.kernels.ragged_paged_attention.v3.tuned_block_sizes import \
     get_tuned_block_sizes
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.quantization import quantize_kv
 from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.jax.base import create_param
 from tpu_inference.layers.jax.layers import RMSNorm
@@ -52,8 +66,9 @@ class MLA(nnx.Module):
     rms_norm_eps: float
 
     # Sharding attributes
-
+    rd_sharding: Sharding = ()
     q_da_sharding: Sharding = ()
+    ap_sharding: Sharding = ()
     anh_sharding: Sharding = ()
     kv_da_sharding: Sharding = ()
 
@@ -113,10 +128,10 @@ class MLA(nnx.Module):
             self.q_da_sharding,
             self.dtype,
             random_init=self.random_init)
-        self.
+        self.kernel_q_up_proj_AP = create_param(
             rngs,
-            (self.q_lora_rank, self.N
-            self.
+            (self.q_lora_rank, self.N * self.qk_head_dim),
+            self.ap_sharding,
             self.dtype,
             random_init=self.random_init,
         )
@@ -127,6 +142,10 @@ class MLA(nnx.Module):
             self.dtype,
             random_init=self.random_init,
         )
+        # NOTE (jacobplatin): we are keeping these variables as 3D because
+        # we would need to reshape them before the below projection,
+        # which caused issues as Qwix wasn't quantizing it correctly
+        # on the abstract pass
         if self.use_mla_kernel:
             self.kernel_k_up_proj_ANH = create_param(
                 rngs,
@@ -143,17 +162,18 @@ class MLA(nnx.Module):
                 random_init=self.random_init,
             )
         else:
-            self.
+            self.kernel_kv_up_proj_AL = create_param(
                 rngs,
-                (self.kv_lora_rank, self.N
-                self.qk_nope_head_dim + self.v_head_dim),
-                self.
+                (self.kv_lora_rank, self.N *
+                 (self.qk_nope_head_dim + self.v_head_dim)),
+                self.
+                ap_sharding,  # NOTE: we use the same sharding for kv_up_proj_AL and kernel_q_up_proj_AP
                 self.dtype,
                 random_init=self.random_init,
             )
-        self.
-            rngs, (self.N
-            self.
+        self.kernel_o_proj_RD = create_param(
+            rngs, (self.N * self.v_head_dim, self.D),
+            self.rd_sharding,
            self.dtype,
            random_init=self.random_init)
        self.q_rms_norm = RMSNorm(
@@ -209,9 +229,10 @@ class MLA(nnx.Module):
            q_TA = jnp.einsum("TD,DA -> TA", x_q_TD,
                              self.kernel_q_down_proj_DA.value)
            q_TA = self.q_rms_norm(q_TA)
-            # Query up projection.
-
-
+            # Query up projection, then reshape to TNH.
+            q_TP = jnp.einsum("TA,AP -> TP", q_TA,
+                              self.kernel_q_up_proj_AP.value)
+            q_TNH = q_TP.reshape(q_TA.shape[0], self.N, self.qk_head_dim)
            # Split the query into nope and rope.
            q_nope_TNH = q_TNH[..., :self.qk_nope_head_dim]
            q_rope_TNH = q_TNH[..., self.qk_nope_head_dim:]
@@ -247,9 +268,12 @@ class MLA(nnx.Module):
            k_rope_SNH = jnp.broadcast_to(
                k_rope_SNH,
                (k_rope_SNH.shape[0], self.N, self.qk_rope_head_dim))
-            # KV up projection.
-
-
+            # KV up projection, then reshape to SN(Hk+Hv).
+            kv_SL = jnp.einsum("SA,AL -> SL", kv_SA,
+                               self.kernel_kv_up_proj_AL.value)
+            kv_nope_SNH = kv_SL.reshape(
+                kv_SA.shape[0], self.N,
+                self.qk_nope_head_dim + self.v_head_dim)
            # Split the latent kv vector into k nope vector and v vector.
            k_nope_SNH = kv_nope_SNH[..., :self.qk_nope_head_dim]
            v_SNH = kv_nope_SNH[..., self.qk_nope_head_dim:]
@@ -287,9 +311,8 @@ class MLA(nnx.Module):
        # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
        k_scale = self._k_scale
        v_scale = self._v_scale
-        k_SNH, v_SNH =
-
-            v_scale)
+        k_SNH, v_SNH = quantize_kv(self.kv_cache_quantized_dtype,
+                                   k_SNH, v_SNH, k_scale, v_scale)
 
        new_kv_cache, outputs_TNH = self.attention(
            is_prefill,
@@ -323,8 +346,10 @@ class MLA(nnx.Module):
        with jax.named_scope("o_proj"):
            outputs_TNH = nnx.with_sharding_constraint(
                outputs_TNH, self.activation_attention_out_td)
-
-
+            outputs_TR = outputs_TNH.reshape(outputs_TNH.shape[0],
+                                             self.N * self.v_head_dim)
+            o_TD = jnp.einsum("TR,RD -> TD", outputs_TR,
+                              self.kernel_o_proj_RD.value)
 
        return new_kv_cache, o_TD
 
@@ -391,12 +416,12 @@ class MLA(nnx.Module):
            return outputs
 
        output_TNH, kv_cache = jax.jit(
-
+            jax.shard_map(
                _ragged_paged_attention,
                mesh=mesh,
                in_specs=in_specs,
                out_specs=out_specs,
-
+                check_vma=False,
            ))(
                q_TNH,
                k_SKH,
@@ -502,12 +527,12 @@ class MLA(nnx.Module):
            return kv_cache, output
 
        kv_cache, output_TNH = jax.jit(
-
+            jax.shard_map(
                _mla_ragged_paged_attention,
                mesh=mesh,
                in_specs=in_specs,
                out_specs=out_specs,
-
+                check_vma=False,
            ), )(
            q_TNA,
            q_rope_TNH,
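
Throughout the hunks above, the attention-kernel wrappers drop the `from jax.experimental import shard_map` import in favor of the top-level `jax.shard_map` API and pass `check_vma=False` (the removed keyword line is truncated in this view). A minimal, illustrative sketch of the new invocation pattern follows; the kernel function, mesh axis, specs and shapes are placeholders rather than code from this package.

    # Illustrative sketch only; not tpu-inference code.
    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, PartitionSpec as P

    def _attention_kernel(q, k):  # stand-in for the ragged-paged-attention wrappers
        return jnp.einsum("td,sd->ts", q, k)

    mesh = Mesh(np.array(jax.devices()), axis_names=("model",))
    sharded_fn = jax.jit(
        jax.shard_map(
            _attention_kernel,
            mesh=mesh,
            in_specs=(P(), P()),  # fully replicated inputs, for simplicity
            out_specs=P(),
            check_vma=False,      # keyword used by the new call sites above
        ))
    out = sharded_fn(jnp.ones((4, 8)), jnp.ones((4, 8)))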
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+# … (same Apache 2.0 license header as above)
+
 from dataclasses import InitVar, dataclass
 from typing import Tuple
 
@@ -5,7 +19,6 @@ import jax
 import jax.numpy as jnp
 from flax import nnx
 from flax.typing import Sharding
-from jax.experimental import shard_map
 from jax.sharding import Mesh
 from jax.sharding import PartitionSpec as P
 
@@ -13,6 +26,7 @@ from tpu_inference import utils
 from tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 import \
     ragged_paged_attention_hd64
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.quantization import quantize_kv
 from tpu_inference.layers.jax.base import create_param
 from tpu_inference.layers.jax.rope import GptOssRotaryEmbedding
 
@@ -185,12 +199,12 @@ class GptOssAttention(nnx.Module):
         )
 
         output_TNH, kv_cache = jax.jit(
-
+            jax.shard_map(
                 _ragged_paged_attention_wrapper,
                 mesh=mesh,
                 in_specs=in_specs,
                 out_specs=out_specs,
-
+                check_vma=False,
             ))(
                 q_TNH,
                 k_SKH,
@@ -235,9 +249,8 @@ class GptOssAttention(nnx.Module):
         # q_scale = self._q_scale
         k_scale = self._k_scale
         v_scale = self._v_scale
-        k_TKH, v_TKH =
-
-            k_scale, v_scale)
+        k_TKH, v_TKH = quantize_kv(self.kv_cache_quantized_dtype, k_TKH,
+                                   v_TKH, k_scale, v_scale)
 
         with jax.named_scope("attn_op"):
             new_kv_cache, attn_out_TNH = self.attention(
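
Both attention modules above (and the Llama4 module below) now route KV-cache quantization through the `quantize_kv` helper imported from the new `tpu_inference/layers/common/quantization.py`. Only the call shape is visible in this diff, so the stand-in below is a hedged sketch that mirrors that shape; its scaling details are an assumption, not the library implementation.

    # Hypothetical stand-in, NOT the tpu-inference implementation.
    import jax.numpy as jnp

    def quantize_kv_sketch(quantized_dtype, k, v, k_scale, v_scale):
        """Same call shape as the quantize_kv call sites above: returns (k, v),
        quantized when a target dtype is configured, unchanged otherwise."""
        if quantized_dtype is None:
            return k, v
        k_q = jnp.round(k / k_scale).astype(quantized_dtype)
        v_q = jnp.round(v / v_scale).astype(quantized_dtype)
        return k_q, v_q

    k_q, v_q = quantize_kv_sketch(jnp.int8, jnp.ones((2, 4, 8)), jnp.ones((2, 4, 8)), 0.1, 0.1)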
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+# … (same Apache 2.0 license header as above)
+
 from dataclasses import dataclass
 
 import jax
@@ -5,8 +19,8 @@ import jax.numpy as jnp
 from flax import nnx
 from jax.sharding import Sharding
 
-from tpu_inference import utils
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.quantization import quantize_kv
 from tpu_inference.layers.jax.attention.attention import Attention, KVCache
 from tpu_inference.layers.jax.rope_interface import apply_rope
 from tpu_inference.logger import init_logger
@@ -114,9 +128,8 @@ class Llama4Attention(Attention):
         # q_scale = self._q_scale
         k_scale = self._k_scale
         v_scale = self._v_scale
-        k_SKH, v_SKH =
-
-            k_scale, v_scale)
+        k_SKH, v_SKH = quantize_kv(self.kv_cache_quantized_dtype, k_SKH,
+                                   v_SKH, k_scale, v_scale)
 
         with jax.named_scope("attn_op"):
             new_kv_cache, outputs_TNH = self.attention(
tpu_inference/layers/jax/base.py
CHANGED
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+# … (same Apache 2.0 license header as above)
+
 import dataclasses
 from dataclasses import dataclass, fields
 from typing import Any, Callable, Mapping
@@ -1,3 +1,16 @@
+# Copyright 2025 Google LLC
+# … (same Apache 2.0 license header as above)
 """
 Current Used Abbreviation for Tensor Dimensions:
 B: Batch size
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+# … (same Apache 2.0 license header as above)
+
 from dataclasses import InitVar, dataclass
 from typing import Any
 
tpu_inference/layers/jax/misc.py
CHANGED
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+# … (same Apache 2.0 license header as above)
+
 import math
 from typing import Tuple
 
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+# … (same Apache 2.0 license header as above; the new file contains only this header)
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+# … (same Apache 2.0 license header as above)
+
 import enum
 from dataclasses import InitVar, dataclass
 from functools import partial
@@ -14,8 +28,8 @@ from qwix._src.providers import ptq
 
 from tpu_inference.layers.jax.base import create_param
 from tpu_inference.layers.jax.layers import FlaxUtils
-from tpu_inference.layers.jax.moe.moe import MoE
-from tpu_inference.models.jax.utils.
+from tpu_inference.layers.jax.moe.moe import CombineExperts, MoE
+from tpu_inference.models.jax.utils.qwix.qwix_utils import (
     manually_quantize_qwix_activation, manually_quantize_qwix_weight)
 
 modeling_flax_utils = FlaxUtils()
@@ -150,6 +164,7 @@ class SparseMoE(MoE):
 
     def __post_init__(self, rngs: nnx.Rngs):
         super().__post_init__(rngs)
+        self.combine_experts = CombineExperts(dtype=self.dtype)
 
         # Derive the expert sharding
         self.expert_axis_name = self.edf_sharding[0]
@@ -331,15 +346,7 @@ class SparseMoE(MoE):
             processed_tokens, jnp.argsort(sort_indices))
         reshaped_tokens_TXD = unsorted_tokens_tD.reshape(
             -1, self.num_experts_per_tok, self.hidden_size)
-
-        output_TD = jnp.einsum(
-            "TXD,TX -> TD",
-            reshaped_tokens_TXD.astype(jnp.float32),
-            router_weights_TX.astype(jnp.float32),
-            precision='float32',
-        )
-
-        return output_TD.astype(self.dtype)
+        return self.combine_experts(reshaped_tokens_TXD, router_weights_TX)
 
     def _gmm(self, inputs, kernel, group_sizes):
         """Performs Grouped Matrix Multiply."""
@@ -575,11 +582,11 @@ class SparseMoE(MoE):
         )
         out_specs = PartitionSpec(*self.activation_ffw_td)
 
-        mapped_moe_fwd = partial(jax.
+        mapped_moe_fwd = partial(jax.shard_map,
                                  mesh=self.mesh,
                                  in_specs=in_specs,
                                  out_specs=out_specs,
-
+                                 check_vma=False)(
             SparseMoE._distributed_sparse_moe_fwd)
 
         kernel_gating_EDF = self.kernel_gating_EDF.value
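
The last hunk above partially applies `jax.shard_map` and then calls the result on the static MoE forward function. A brief sketch of that pattern under a one-axis mesh; the wrapped function, axis name and specs are placeholders rather than the SparseMoE forward.

    # Illustrative only; not tpu-inference code.
    from functools import partial

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, PartitionSpec

    mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
    # partial(...) builds a reusable shard_map wrapper; applying it to a function
    # returns the mapped version, mirroring mapped_moe_fwd above.
    mapped_fwd = partial(jax.shard_map,
                         mesh=mesh,
                         in_specs=(PartitionSpec("data"),),
                         out_specs=PartitionSpec("data"),
                         check_vma=False)(lambda x: x * 2.0)

    y = mapped_fwd(jnp.arange(4.0 * len(jax.devices())))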
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+# … (same Apache 2.0 license header as above)
+
 from dataclasses import InitVar, dataclass
 
 import jax
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+# … (same Apache 2.0 license header as above)
+
 from dataclasses import InitVar, dataclass
 
 import jax
@@ -12,6 +26,29 @@ from tpu_inference.layers.jax.layers import FlaxUtils
 modeling_flax_utils = FlaxUtils()
 
 
+@dataclass(kw_only=True)
+class CombineExperts(nnx.Module):
+    """Combines expert outputs with router weights.
+
+    Supports `TED,TE -> TD` when passed expert outputs, using float32
+    accumulation for numerical stability, then casting back to the target
+    dtype.
+    """
+
+    dtype: jnp.dtype
+
+    def __call__(self, expert_outputs_TED: Float, weights_TE: Float) -> Float:
+        with jax.named_scope("combine_experts"):
+            output_TD = jnp.einsum(
+                "TED,TE -> TD",
+                expert_outputs_TED.astype(jnp.float32),
+                weights_TE.astype(jnp.float32),
+                precision="float32",
+            )
+
+        return output_TD.astype(self.dtype)
+
+
 @dataclass(kw_only=True)
 class Router(nnx.Module):
     """Router module for Mixture-of-Experts (MoE) layers.
@@ -139,6 +176,9 @@ class MoE(nnx.Module):
                                     sharding=self.efd_sharding,
                                     random_init=self.random_init)
 
+        # Shared combine module for combine path
+        self.combine_experts = CombineExperts(dtype=self.dtype)
+
     def _moe_fwd_preapply_router_weights(self, x_TD: jax.Array, weights_TE):
         """Performs the forward pass of the MoE experts with router weights pre-applied to the inputs.
 
@@ -204,6 +244,6 @@ class MoE(nnx.Module):
         with jax.named_scope("down_projection"):
             down_proj_TED = jnp.einsum('TEF,EFD -> TED', fuse_TEF,
                                        self.kernel_down_proj_EFD.value)
-
-
-        return output_TD
+        # Combine across experts
+        output_TD = self.combine_experts(down_proj_TED, weights)
+        return output_TD
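
A small usage sketch for the `CombineExperts` module added above, with dummy shapes (T=4 tokens, E=2 experts per token, D=8 hidden). It assumes the class exactly as defined in this hunk: the einsum accumulates in float32 and the result is cast back to the module dtype.

    # Assumes CombineExperts as added above (tpu_inference.layers.jax.moe.moe).
    import jax.numpy as jnp

    combine = CombineExperts(dtype=jnp.bfloat16)
    expert_outputs_TED = jnp.ones((4, 2, 8), dtype=jnp.bfloat16)   # per-token expert outputs
    router_weights_TE = jnp.full((4, 2), 0.5, dtype=jnp.bfloat16)  # per-token routing weights
    out_TD = combine(expert_outputs_TED, router_weights_TE)
    assert out_TD.shape == (4, 8) and out_TD.dtype == jnp.bfloat16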
@@ -0,0 +1,53 @@
+# Copyright 2025 Google LLC
+# … (same Apache 2.0 license header as above)
+
+from typing import List, Protocol
+
+from flax import nnx
+from vllm.distributed import get_pp_group
+from vllm.distributed.utils import get_pp_indices
+
+
+class PPMissingLayer(nnx.Module):
+    """
+    A placeholder layer for missing layers in a pipeline parallel model.
+    """
+
+    def __init__(self, *args, **kwargs):
+        pass
+
+    def __call__(self, *args, **kwargs):
+        """Return the first arg from args or the first value from kwargs."""
+        return args[0] if args else next(iter(kwargs.values()))
+
+
+class LayerFn(Protocol):
+
+    def __call__(self) -> nnx.Module:
+        ...
+
+
+def make_layers(
+    num_hidden_layers: int,
+    layer_fn: LayerFn,
+) -> tuple[int, int, List[nnx.Module]]:
+    start_layer, end_layer = get_pp_indices(num_hidden_layers,
+                                            get_pp_group().rank_in_group,
+                                            get_pp_group().world_size)
+
+    layers = [PPMissingLayer() for _ in range(start_layer)] \
+        + [layer_fn() for _ in range(start_layer, end_layer)] \
+        + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]
+
+    return start_layer, end_layer, layers
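
The new module builds only the pipeline stage owned by the current rank; positions owned by other ranks are filled with `PPMissingLayer` placeholders (`make_layers` itself additionally requires vLLM's pipeline-parallel group to be initialized before use). A short sketch of the placeholder contract, assuming the class exactly as defined above:

    # Assumes PPMissingLayer from the new file above.
    import jax.numpy as jnp

    placeholder = PPMissingLayer()            # constructor arguments are ignored
    hidden = jnp.ones((2, 4))
    assert placeholder(hidden) is hidden                 # first positional arg passes through
    assert placeholder(hidden_states=hidden) is hidden   # otherwise the first kwarg value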
tpu_inference/layers/jax/rope.py
CHANGED
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+# … (same Apache 2.0 license header as above)
+
 import math
 from dataclasses import dataclass, field
 from typing import Optional, Tuple
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+# … (same Apache 2.0 license header as above)
+
 import math
 from typing import Any, Dict
 
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+# … (same Apache 2.0 license header as above; the new file contains only this header)
@@ -1,3 +1,16 @@
+# Copyright 2025 Google LLC
+# … (same Apache 2.0 license header as above)
 """
 JAX-based rejection sampler for speculative decoding on TPU.
 
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+# … (same Apache 2.0 license header as above)
+
 import functools
 
 import jax
@@ -28,7 +42,7 @@ def sample(
     if tpu_sampling_metadata.do_sampling:
         # Unshard the logits explicity to avoid latency increase.
         logits = jax.lax.with_sharding_constraint(
-            logits, NamedSharding(mesh, P(ShardingAxisName.
+            logits, NamedSharding(mesh, P(ShardingAxisName.MLP_DATA, None)))
         greedy_sampled = jnp.argmax(logits, axis=-1)
     if not tpu_sampling_metadata.do_sampling:
         return greedy_sampled