tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +317 -34
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +26 -6
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +25 -4
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +25 -12
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +32 -9
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +101 -494
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
- tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +112 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +18 -5
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +179 -51
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
- tpu_inference/models/jax/utils/weight_utils.py +234 -155
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +51 -72
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +180 -80
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +55 -33
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +16 -3
- tpu_inference/runner/tpu_runner.py +124 -61
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +84 -22
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +66 -52
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -186
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tpu_inference/models/jax/gpt_oss.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import re
 from dataclasses import dataclass
 from typing import List, Optional, Tuple
@@ -11,6 +25,9 @@ from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
 from vllm.config import VllmConfig
 
+from tpu_inference.layers.common.quant_methods import MXFP4
+from tpu_inference.layers.common.quantization import (
+    dequantize_tensor_from_mxfp4_packed, e8m0_to_fp32, u8_unpack_e2m1)
 from tpu_inference.layers.jax.attention.gpt_oss_attention import (
     AttentionMetadata, GptOssAttention)
 from tpu_inference.layers.jax.constants import KVCacheType
@@ -18,8 +35,6 @@ from tpu_inference.layers.jax.layers import Embedder, LMhead, RMSNorm
 from tpu_inference.layers.jax.moe.gpt_oss_moe import GptOssMoE, GptOssRouter
 from tpu_inference.layers.jax.transformer_block import TransformerBlock
 from tpu_inference.logger import init_logger
-from tpu_inference.models.jax.utils.quantization.mxfp4_utils import (
-    MXFP4_QUANT_METHOD, dequant_mxfp4_to_bf16, unpack_mxfp4_to_fp32)
 from tpu_inference.models.jax.utils.weight_utils import (
     get_param, model_weights_generator, print_param_info)
 
@@ -102,9 +117,9 @@ class GptOss(nnx.Module):
             rope_ntk_beta=rope_ntk_beta,
             rngs=self.rng,
             random_init=self.random_init,
-            query_tnh=P(
-            keyvalue_skh=P(
-            attn_o_tnh=P(
+            query_tnh=P("data", 'model', None),
+            keyvalue_skh=P("data", 'model', None),
+            attn_o_tnh=P("data", 'model', None),
             dnh_sharding=P(None, 'model', None),
             dkh_sharding=P(None, 'model', None),
             nhd_sharding=P('model', None, None),
@@ -205,7 +220,7 @@
 
         # MXFP4 checkpoints swap last two dims for MoE to place packed dim at most minor
         swap_mlp_transform = transforms[
-            "swap_last2"] if quant_method ==
+            "swap_last2"] if quant_method == MXFP4 else None
 
         mappings = {
             # Embeddings, Norms, and LM Head
@@ -285,7 +300,7 @@
         # Build a pool of weights with MXFP4 experts combined if neededs
         pool: dict[str, torch.Tensor | tuple] = (self._build_mxfp4_pool(
             names_and_weights_generator,
-            mappings) if quant_method ==
+            mappings) if quant_method == MXFP4 else {
                 loaded_name: loaded_weight
                 for loaded_name, loaded_weight in names_and_weights_generator
             })
@@ -316,8 +331,9 @@
             blocks_u8, scales_u8 = loaded_weight
             # Quantized param (QArray): set qvalue/scale directly and skip regular path
             if hasattr(model_weight, "array"):  # QArray check
-                codes_fp32_t
-
+                codes_fp32_t = u8_unpack_e2m1(blocks_u8).astype(
+                    jnp.float32)
+                scales_fp32_t = e8m0_to_fp32(scales_u8)
                 self._load_mxfp4(
                     model_weight=model_weight,
                     codes_fp32_t=codes_fp32_t,
@@ -328,7 +344,7 @@
                 print_param_info(model_weight, loaded_name)
                 continue
             # Not a QArray: dequantize MXFP4 to BF16 full weights
-            prepared_weight =
+            prepared_weight = dequantize_tensor_from_mxfp4_packed(
                 blocks_u8, scales_u8)
 
             # Single regular-tensor load call (BF16 or dequantized MXFP4)
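
The gpt_oss.py changes replace the deleted `mxfp4_utils` helpers (`MXFP4_QUANT_METHOD`, `dequant_mxfp4_to_bf16`, `unpack_mxfp4_to_fp32`) with `MXFP4`, `u8_unpack_e2m1`, `e8m0_to_fp32`, and `dequantize_tensor_from_mxfp4_packed` from the new shared `tpu_inference.layers.common.quantization` module. The sketch below illustrates what helpers with these names plausibly compute: MXFP4 packs two 4-bit e2m1 codes per byte, with one exponent-only e8m0 scale shared per block of values. This is a minimal sketch, not the package's implementation; the nibble order, 32-element block size, and bfloat16 output are assumptions.

```python
import jax.numpy as jnp

# All values representable by a 4-bit e2m1 code: sign bit, 2 exponent bits,
# 1 mantissa bit. Codes 0-7 are positive, 8-15 their negatives.
E2M1_LUT = jnp.array(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0], dtype=jnp.float32)

def u8_unpack_e2m1(blocks_u8):
    """Unpack two e2m1 codes per byte (low nibble first, an assumption)."""
    lo = blocks_u8 & 0x0F
    hi = (blocks_u8 >> 4) & 0x0F
    codes = jnp.stack([lo, hi], axis=-1).reshape(*blocks_u8.shape[:-1], -1)
    return E2M1_LUT[codes]

def e8m0_to_fp32(scales_u8):
    """e8m0 is exponent-only: the decoded scale is 2**(code - 127)."""
    return jnp.exp2(scales_u8.astype(jnp.float32) - 127.0)

def dequantize_tensor_from_mxfp4_packed(blocks_u8, scales_u8, block_size=32):
    """Dequantize packed MXFP4: each block of codes shares one e8m0 scale."""
    vals = u8_unpack_e2m1(blocks_u8)                       # [..., n]
    vals = vals.reshape(*vals.shape[:-1], -1, block_size)  # [..., n/32, 32]
    out = vals * e8m0_to_fp32(scales_u8)[..., None]
    return out.reshape(*out.shape[:-2], -1).astype(jnp.bfloat16)
```
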
tpu_inference/models/jax/jax_intermediate_tensor.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, Union
 
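
The hunk above only prepends the license header, so the class defined in jax_intermediate_tensor.py is not shown. From its usage in the llama3.py diff below (constructed from a dict, indexed by key, and handed between pipeline stages), a `JaxIntermediateTensors` container plausibly looks like this hedged sketch; the real class may carry extra machinery:

```python
from dataclasses import dataclass
from typing import Dict

import jax

@dataclass
class JaxIntermediateTensors:
    """Keyed bundle of activations passed from one pipeline stage to the next."""
    tensors: Dict[str, jax.Array]

    def __getitem__(self, key: str) -> jax.Array:
        return self.tensors[key]
```
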
tpu_inference/models/jax/llama3.py

@@ -1,3 +1,18 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from itertools import islice
 from typing import List, Optional, Tuple
 
 import jax
@@ -8,13 +23,19 @@ from transformers import LlamaConfig, modeling_flax_utils
 from vllm.config import VllmConfig
 
 from tpu_inference import utils
+from tpu_inference.distributed.jax_parallel_state import get_pp_group
 from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.quantization import quantize_kv
 from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.jax.pp_utils import PPMissingLayer, make_layers
 from tpu_inference.layers.jax.rope_interface import apply_rope
 from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.jax_intermediate_tensor import \
+    JaxIntermediateTensors
 from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
                                                          load_hf_weights)
+from tpu_inference.utils import get_mesh_shape_product
 
 logger = init_logger(__name__)
 
@@ -79,7 +100,8 @@ class LlamaAttention(nnx.Module):
                                             self.hidden_size // self.num_heads)
         self.head_dim = utils.get_padded_head_dim(self.head_dim_original)
 
-        sharding_size = mesh
+        sharding_size = get_mesh_shape_product(mesh,
+                                               ShardingAxisName.MLP_TENSOR)
         self.num_heads = utils.get_padded_num_heads(self.num_heads,
                                                     sharding_size)
         self.num_kv_heads = utils.get_padded_num_heads(self.num_kv_heads,
@@ -152,8 +174,8 @@
         # q_scale = self._q_scale
         k_scale = self._k_scale
         v_scale = self._v_scale
-        k, v =
-
+        k, v = quantize_kv(self.kv_cache_quantized_dtype, k, v, k_scale,
+                           v_scale)
         new_kv_cache, outputs = attention(
             kv_cache,
             q,
@@ -235,38 +257,52 @@ class LlamaModel(nnx.Module):
         rms_norm_eps = hf_config.rms_norm_eps
         hidden_size = hf_config.hidden_size
 
-        self.
-
-
-
-
-
-
-
-
-
+        self.is_first_rank = get_pp_group().is_first_rank
+        self.is_last_rank = get_pp_group().is_last_rank
+
+        if self.is_first_rank or (hf_config.tie_word_embeddings
+                                  and self.is_last_rank):
+            self.embed = nnx.Embed(
+                num_embeddings=vocab_size,
+                features=hidden_size,
+                param_dtype=dtype,
+                embedding_init=nnx.with_partitioning(
+                    init_fn, (ShardingAxisName.VOCAB, None)),
+                rngs=rng,
+            )
+        else:
+            self.embed = PPMissingLayer()
+
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            hf_config.num_hidden_layers,
+            lambda: LlamaDecoderLayer(
                 config=hf_config,
                 dtype=dtype,
                 rng=rng,
                 mesh=mesh,
                 # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
-                kv_cache_dtype=vllm_config.cache_config.cache_dtype)
-
-
-
-
-
-
-
-            rngs=rng,
-        )
-        if model_config.hf_config.tie_word_embeddings:
-            self.lm_head = self.embed.embedding
-        else:
-            self.lm_head = nnx.Param(
-                init_fn(rng.params(), (hidden_size, vocab_size), dtype),
-                sharding=(None, ShardingAxisName.VOCAB),
+                kv_cache_dtype=vllm_config.cache_config.cache_dtype))
+        if self.is_last_rank:
+            self.norm = nnx.RMSNorm(
+                hidden_size,
+                epsilon=rms_norm_eps,
+                param_dtype=dtype,
+                scale_init=nnx.with_partitioning(init_fn, (None, )),
+                rngs=rng,
             )
+        else:
+            self.norm = PPMissingLayer()
+
+        if self.is_last_rank:
+            if model_config.hf_config.tie_word_embeddings:
+                self.lm_head = self.embed.embedding
+            else:
+                self.lm_head = nnx.Param(
+                    init_fn(rng.params(), (hidden_size, vocab_size), dtype),
+                    sharding=(None, ShardingAxisName.VOCAB),
+                )
+        else:
+            self.lm_head = PPMissingLayer()
 
         self.aux_hidden_state_layers = []
         if vllm_config.speculative_config and vllm_config.speculative_config.method == "eagle3":
@@ -282,10 +318,18 @@
         kv_caches: List[jax.Array],
         input_ids: jax.Array,
         attention_metadata: AttentionMetadata,
-
-
+        intermediate_tensors: JaxIntermediateTensors | None,
+    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]] | Tuple[
+            List[jax.Array], JaxIntermediateTensors]:
+        if self.is_first_rank:
+            x = self.embed(input_ids)
+        else:
+            assert intermediate_tensors is not None
+            x = intermediate_tensors["hidden_states"]
+
         aux_hidden_states = []
-        for i, layer in enumerate(
+        for i, layer in enumerate(
+                islice(self.layers, self.start_layer, self.end_layer)):
             if i in self.aux_hidden_state_layers:
                 aux_hidden_states.append(x)
             kv_cache = kv_caches[i]
@@ -295,6 +339,10 @@
                 attention_metadata,
             )
             kv_caches[i] = kv_cache
+        if not self.is_last_rank:
+            # Note: add aux_hidden_states to make the output spec consistent.
+            return kv_caches, JaxIntermediateTensors({"hidden_states":
+                                                      x}), aux_hidden_states
         x = self.norm(x)
         return kv_caches, x, aux_hidden_states
 
@@ -313,19 +361,33 @@ class LlamaForCausalLM(nnx.Module):
             mesh=mesh,
         )
 
+        self.pp_missing_layers = []
+        for path, module in nnx.iter_graph(self.model):
+            if isinstance(module, PPMissingLayer):
+                # the path should be sth like ('layers', '0')
+                self.pp_missing_layers.append('.'.join([str(s) for s in path]))
+
     def __call__(
         self,
         kv_caches: List[jax.Array],
         input_ids: jax.Array,
         attention_metadata: AttentionMetadata,
+        _input_embeds=None,
+        _input_positions=None,
+        _layer_name_to_kv_cache=None,
+        _lora_metadata=None,
+        intermediate_tensors: JaxIntermediateTensors | None = None,
+        _is_first_rank: bool | None = None,
+        _is_last_rank: bool | None = None,
         *args,
-    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]
-
+    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]] | Tuple[
+            List[jax.Array], JaxIntermediateTensors]:
+        return self.model(
             kv_caches,
             input_ids,
             attention_metadata,
+            intermediate_tensors,
         )
-        return kv_caches, x, aux_hidden_states
 
     def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
         if self.vllm_config.model_config.hf_config.tie_word_embeddings:
@@ -373,4 +435,5 @@
         load_hf_weights(vllm_config=self.vllm_config,
                         model=self,
                         metadata_map=metadata_map,
-                        mesh=self.mesh
+                        mesh=self.mesh,
+                        pp_missing_layers=self.pp_missing_layers)
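
The llama3.py diff wires in pipeline parallelism: each rank materializes only its slice of decoder layers through `make_layers`, fills the remaining positions with `PPMissingLayer` placeholders, walks the local slice with `islice(self.layers, self.start_layer, self.end_layer)`, and, on non-final ranks, returns a `JaxIntermediateTensors` carrying `hidden_states` to the next stage instead of applying the final norm. The collected `pp_missing_layers` paths then let `load_hf_weights` skip checkpoints owned by other ranks. Below is a minimal sketch of that partitioning pattern, assuming an even layer split and explicit rank arguments (the real helpers in `tpu_inference.layers.jax.pp_utils` presumably derive both from `get_pp_group()`):

```python
class PPMissingLayer:
    """Placeholder for a decoder layer that lives on another pipeline stage."""

    def __call__(self, x, *args, **kwargs):
        return x  # identity; never expected to run on this rank

def make_layers(num_layers, layer_fn, rank, world_size):
    """Materialize only this stage's contiguous slice of layers."""
    per_stage = num_layers // world_size  # assumes an even split
    start, end = rank * per_stage, (rank + 1) * per_stage
    layers = [layer_fn() if start <= i < end else PPMissingLayer()
              for i in range(num_layers)]
    return start, end, layers
```
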
tpu_inference/models/jax/llama4.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import re
 from typing import List, Optional, Tuple
 
tpu_inference/models/jax/llama_eagle3.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import List, Tuple
 
 import jax
@@ -304,6 +318,8 @@ class EagleLlama3ForCausalLM(nnx.Module):
             "fc": "model.fc.kernel",
             "lm_head": "lm_head.kernel",
             "d2t": "draft_id_to_target_id",
+            "embed_tokens":
+            "model.embed_tokens.embedding",  # Some checkpoints need this
         }
 
         # Define keys to keep in original dtype (e.g., float32 for stability)
@@ -311,8 +327,6 @@
             r".*d2t.*",
         ]
 
-        # `embed_tokens` is shared between target and draft.
-        exclude_regex = [r".*embed_tokens.*"]
         metadata_map = get_default_maps(
             self.vllm_config.speculative_config.draft_model_config, self.mesh,
             mappings)
@@ -325,10 +339,9 @@
                        metadata_map=metadata_map,
                        mesh=self.mesh,
                        is_draft_model=True,
-                       keep_original_dtype_keys_regex=keep_original_dtype_keys_regex
-                       exclude_regex=exclude_regex if exclude_regex else None)
+                       keep_original_dtype_keys_regex=keep_original_dtype_keys_regex)
 
-        # If the embedding is not initialized, initialize it with a
+        # If the embedding is not initialized, initialize it with a dummy array here to pass jit compilation. The real weights will be shared from the target model in eagle3 class.
         if isinstance(self.model.embed_tokens.embedding.value,
                       jax.ShapeDtypeStruct):
             self.model.embed_tokens.embedding.value = jnp.zeros(
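
The final hunk's expanded comment points at an easy-to-miss pattern: an nnx param whose value is still a `jax.ShapeDtypeStruct` is abstract (never materialized), so the draft model zero-fills it only to let jit tracing succeed; the eagle3 wrapper later shares the target model's real embedding table. A hedged sketch of that check (the `jnp.zeros(` call is truncated in the diff above, so its exact arguments are assumed):

```python
import jax
import jax.numpy as jnp

def ensure_embedding_materialized(embedding_param):
    """Swap an abstract (ShapeDtypeStruct) embedding for zeros of the same
    shape and dtype, purely so jit compilation can trace the draft model."""
    if isinstance(embedding_param.value, jax.ShapeDtypeStruct):
        spec = embedding_param.value
        embedding_param.value = jnp.zeros(spec.shape, dtype=spec.dtype)
```
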
tpu_inference/models/jax/llama_guard_4.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import re
 from typing import Any, List, Optional, Tuple
 
@@ -242,7 +256,7 @@ class LlamaGuard4ForCausalLM(nnx.Module):
                                self.lm_head.input_embedding_table_DV.value)
         return logits_TV
 
-    def
+    def embed_input_ids(
         self,
         input_ids: jax.Array,
         multimodal_embeddings: Optional[List[jax.Array]] = None
tpu_inference/models/jax/qwen2.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import List, Optional, Tuple
 
 import jax
@@ -10,6 +24,7 @@ from vllm.config import VllmConfig
 from tpu_inference import utils
 from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.quantization import quantize_kv
 from tpu_inference.layers.jax.rope_interface import apply_rope
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
@@ -152,8 +167,8 @@ class Qwen2Attention(nnx.Module):
         # q_scale = self._q_scale
         k_scale = self._k_scale
         v_scale = self._v_scale
-        k, v =
-
+        k, v = quantize_kv(self.kv_cache_quantized_dtype, k, v, k_scale,
+                           v_scale)
         new_kv_cache, outputs = attention(
             kv_cache,
             q,