tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +22 -1
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +167 -97
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +31 -9
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +280 -210
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +77 -36
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +91 -31
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -4
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -71
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +158 -63
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +53 -30
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +105 -57
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +65 -19
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +65 -52
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
- tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511270815.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
tpu_inference/models/jax/gpt_oss.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import re
 from dataclasses import dataclass
 from typing import List, Optional, Tuple

@@ -11,6 +25,9 @@ from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
 from vllm.config import VllmConfig
 
+from tpu_inference.layers.common.quant_methods import MXFP4
+from tpu_inference.layers.common.quantization import (
+    dequantize_tensor_from_mxfp4_packed, e8m0_to_fp32, u8_unpack_e2m1)
 from tpu_inference.layers.jax.attention.gpt_oss_attention import (
     AttentionMetadata, GptOssAttention)
 from tpu_inference.layers.jax.constants import KVCacheType

@@ -18,8 +35,6 @@ from tpu_inference.layers.jax.layers import Embedder, LMhead, RMSNorm
 from tpu_inference.layers.jax.moe.gpt_oss_moe import GptOssMoE, GptOssRouter
 from tpu_inference.layers.jax.transformer_block import TransformerBlock
 from tpu_inference.logger import init_logger
-from tpu_inference.models.jax.utils.quantization.mxfp4_utils import (
-    MXFP4_QUANT_METHOD, dequant_mxfp4_to_bf16, unpack_mxfp4_to_fp32)
 from tpu_inference.models.jax.utils.weight_utils import (
     get_param, model_weights_generator, print_param_info)
 

@@ -102,9 +117,9 @@ class GptOss(nnx.Module):
             rope_ntk_beta=rope_ntk_beta,
             rngs=self.rng,
             random_init=self.random_init,
-            query_tnh=P(
-            keyvalue_skh=P(
-            attn_o_tnh=P(
+            query_tnh=P("data", 'model', None),
+            keyvalue_skh=P("data", 'model', None),
+            attn_o_tnh=P("data", 'model', None),
             dnh_sharding=P(None, 'model', None),
             dkh_sharding=P(None, 'model', None),
             nhd_sharding=P('model', None, None),

@@ -205,7 +220,7 @@ class GptOss(nnx.Module):
 
         # MXFP4 checkpoints swap last two dims for MoE to place packed dim at most minor
         swap_mlp_transform = transforms[
-            "swap_last2"] if quant_method ==
+            "swap_last2"] if quant_method == MXFP4 else None
 
         mappings = {
             # Embeddings, Norms, and LM Head

@@ -285,7 +300,7 @@ class GptOss(nnx.Module):
         # Build a pool of weights with MXFP4 experts combined if neededs
         pool: dict[str, torch.Tensor | tuple] = (self._build_mxfp4_pool(
             names_and_weights_generator,
-            mappings) if quant_method ==
+            mappings) if quant_method == MXFP4 else {
                 loaded_name: loaded_weight
                 for loaded_name, loaded_weight in names_and_weights_generator
             })

@@ -316,8 +331,9 @@ class GptOss(nnx.Module):
                 blocks_u8, scales_u8 = loaded_weight
                 # Quantized param (QArray): set qvalue/scale directly and skip regular path
                 if hasattr(model_weight, "array"):  # QArray check
-                    codes_fp32_t
-
+                    codes_fp32_t = u8_unpack_e2m1(blocks_u8).astype(
+                        jnp.float32)
+                    scales_fp32_t = e8m0_to_fp32(scales_u8)
                     self._load_mxfp4(
                         model_weight=model_weight,
                         codes_fp32_t=codes_fp32_t,

@@ -328,7 +344,7 @@ class GptOss(nnx.Module):
                     print_param_info(model_weight, loaded_name)
                     continue
                 # Not a QArray: dequantize MXFP4 to BF16 full weights
-                prepared_weight = dequantize_tensor_from_mxfp4_packed(
                     blocks_u8, scales_u8)
 
                 # Single regular-tensor load call (BF16 or dequantized MXFP4)
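Note on the MXFP4 changes above: the loader now decodes packed checkpoints with u8_unpack_e2m1 (two 4-bit e2m1 codes per byte) and e8m0_to_fp32 (one power-of-two scale per block), and a full dequantize is just codes times scales. A minimal NumPy sketch of that standard MXFP4 decode, assuming 32-element blocks and low-nibble-first packing; the bodies below are illustrative, not the package's implementation:

import numpy as np

# e2m1 decode table: 1 sign, 2 exponent, 1 mantissa bit -> 16 values.
E2M1_VALUES = np.array(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
    dtype=np.float32)

def u8_unpack_e2m1(blocks_u8):
    """Expand two 4-bit e2m1 codes per byte into fp32 values."""
    lo = blocks_u8 & 0x0F  # low nibble first (a layout assumption)
    hi = blocks_u8 >> 4
    codes = np.stack([lo, hi], axis=-1).reshape(*blocks_u8.shape[:-1], -1)
    return E2M1_VALUES[codes]

def e8m0_to_fp32(scales_u8):
    """An e8m0 scale is a biased power-of-two exponent: 2**(x - 127)."""
    return np.exp2(scales_u8.astype(np.float32) - 127.0)

def dequantize_tensor_from_mxfp4_packed(blocks_u8, scales_u8, block=32):
    """Apply each block's scale to its group of 32 decoded codes."""
    vals = u8_unpack_e2m1(blocks_u8)                    # (..., n_blocks * 32)
    grouped = vals.reshape(*vals.shape[:-1], -1, block)
    return (grouped * e8m0_to_fp32(scales_u8)[..., None]).reshape(vals.shape)

This also explains the two branches in the loader: the QArray path keeps codes and scales separate on device, while the fallback materializes dequantized full weights once at load time.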
tpu_inference/models/jax/jax_intermediate_tensor.py

@@ -1,3 +1,17 @@  (adds the same 14-line Apache 2.0 license header as in gpt_oss.py)
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, Union
 
tpu_inference/models/jax/llama3.py

@@ -1,3 +1,18 @@  (adds the same Apache 2.0 license header as above, then:)
+
+from itertools import islice
 from typing import List, Optional, Tuple
 
 import jax

@@ -8,13 +23,19 @@ from transformers import LlamaConfig, modeling_flax_utils
 from vllm.config import VllmConfig
 
 from tpu_inference import utils
+from tpu_inference.distributed.jax_parallel_state import get_pp_group
 from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.quantization import quantize_kv
 from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.jax.pp_utils import PPMissingLayer, make_layers
 from tpu_inference.layers.jax.rope_interface import apply_rope
 from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.jax_intermediate_tensor import \
+    JaxIntermediateTensors
 from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
                                                          load_hf_weights)
+from tpu_inference.utils import get_mesh_shape_product
 
 logger = init_logger(__name__)
 

@@ -79,7 +100,8 @@ class LlamaAttention(nnx.Module):
                          self.hidden_size // self.num_heads)
         self.head_dim = utils.get_padded_head_dim(self.head_dim_original)
 
-        sharding_size = mesh
+        sharding_size = get_mesh_shape_product(mesh,
+                                               ShardingAxisName.MLP_TENSOR)
         self.num_heads = utils.get_padded_num_heads(self.num_heads,
                                                     sharding_size)
         self.num_kv_heads = utils.get_padded_num_heads(self.num_kv_heads,

@@ -152,8 +174,8 @@ class LlamaAttention(nnx.Module):
             # q_scale = self._q_scale
             k_scale = self._k_scale
             v_scale = self._v_scale
-            k, v =
-
+            k, v = quantize_kv(self.kv_cache_quantized_dtype, k, v, k_scale,
+                               v_scale)
         new_kv_cache, outputs = attention(
             kv_cache,
             q,

@@ -235,38 +257,52 @@ class LlamaModel(nnx.Module):
         rms_norm_eps = hf_config.rms_norm_eps
         hidden_size = hf_config.hidden_size
 
-        self.
-
-
-
-
-
-
-
-
-
+        self.is_first_rank = get_pp_group().is_first_rank
+        self.is_last_rank = get_pp_group().is_last_rank
+
+        if self.is_first_rank or (hf_config.tie_word_embeddings
+                                  and self.is_last_rank):
+            self.embed = nnx.Embed(
+                num_embeddings=vocab_size,
+                features=hidden_size,
+                param_dtype=dtype,
+                embedding_init=nnx.with_partitioning(
+                    init_fn, (ShardingAxisName.VOCAB, None)),
+                rngs=rng,
+            )
+        else:
+            self.embed = PPMissingLayer()
+
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            hf_config.num_hidden_layers,
+            lambda: LlamaDecoderLayer(
                 config=hf_config,
                 dtype=dtype,
                 rng=rng,
                 mesh=mesh,
                 # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
-                kv_cache_dtype=vllm_config.cache_config.cache_dtype)
-
-
-
-
-
-
-
-            rngs=rng,
-        )
-        if model_config.hf_config.tie_word_embeddings:
-            self.lm_head = self.embed.embedding
-        else:
-            self.lm_head = nnx.Param(
-                init_fn(rng.params(), (hidden_size, vocab_size), dtype),
-                sharding=(None, ShardingAxisName.VOCAB),
+                kv_cache_dtype=vllm_config.cache_config.cache_dtype))
+        if self.is_last_rank:
+            self.norm = nnx.RMSNorm(
+                hidden_size,
+                epsilon=rms_norm_eps,
+                param_dtype=dtype,
+                scale_init=nnx.with_partitioning(init_fn, (None, )),
+                rngs=rng,
             )
+        else:
+            self.norm = PPMissingLayer()
+
+        if self.is_last_rank:
+            if model_config.hf_config.tie_word_embeddings:
+                self.lm_head = self.embed.embedding
+            else:
+                self.lm_head = nnx.Param(
+                    init_fn(rng.params(), (hidden_size, vocab_size), dtype),
+                    sharding=(None, ShardingAxisName.VOCAB),
+                )
+        else:
+            self.lm_head = PPMissingLayer()
 
         self.aux_hidden_state_layers = []
         if vllm_config.speculative_config and vllm_config.speculative_config.method == "eagle3":

@@ -282,10 +318,18 @@ class LlamaModel(nnx.Module):
         kv_caches: List[jax.Array],
         input_ids: jax.Array,
         attention_metadata: AttentionMetadata,
-
-
+        intermediate_tensors: JaxIntermediateTensors | None,
+    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]] | Tuple[
+            List[jax.Array], JaxIntermediateTensors]:
+        if self.is_first_rank:
+            x = self.embed(input_ids)
+        else:
+            assert intermediate_tensors is not None
+            x = intermediate_tensors["hidden_states"]
+
         aux_hidden_states = []
-        for i, layer in enumerate(
+        for i, layer in enumerate(
+                islice(self.layers, self.start_layer, self.end_layer)):
             if i in self.aux_hidden_state_layers:
                 aux_hidden_states.append(x)
             kv_cache = kv_caches[i]

@@ -295,6 +339,10 @@ class LlamaModel(nnx.Module):
                 attention_metadata,
             )
             kv_caches[i] = kv_cache
+        if not self.is_last_rank:
+            # Note: add aux_hidden_states to make the output spec consistent.
+            return kv_caches, JaxIntermediateTensors({"hidden_states":
+                                                      x}), aux_hidden_states
         x = self.norm(x)
         return kv_caches, x, aux_hidden_states
 

@@ -313,19 +361,33 @@ class LlamaForCausalLM(nnx.Module):
             mesh=mesh,
         )
 
+        self.pp_missing_layers = []
+        for path, module in nnx.iter_graph(self.model):
+            if isinstance(module, PPMissingLayer):
+                # the path should be sth like ('layers', '0')
+                self.pp_missing_layers.append('.'.join([str(s) for s in path]))
+
     def __call__(
         self,
         kv_caches: List[jax.Array],
         input_ids: jax.Array,
         attention_metadata: AttentionMetadata,
+        _input_embeds=None,
+        _input_positions=None,
+        _layer_name_to_kv_cache=None,
+        _lora_metadata=None,
+        intermediate_tensors: JaxIntermediateTensors | None = None,
+        _is_first_rank: bool | None = None,
+        _is_last_rank: bool | None = None,
         *args,
-    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]
-
+    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]] | Tuple[
+            List[jax.Array], JaxIntermediateTensors]:
+        return self.model(
             kv_caches,
             input_ids,
             attention_metadata,
+            intermediate_tensors,
         )
-        return kv_caches, x, aux_hidden_states
 
     def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
         if self.vllm_config.model_config.hf_config.tie_word_embeddings:

@@ -373,4 +435,5 @@ class LlamaForCausalLM(nnx.Module):
         load_hf_weights(vllm_config=self.vllm_config,
                         model=self,
                         metadata_map=metadata_map,
-                        mesh=self.mesh
+                        mesh=self.mesh,
+                        pp_missing_layers=self.pp_missing_layers)
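Note on the pipeline-parallel rewrite above: make_layers assigns each stage a contiguous slice [start_layer, end_layer) of the decoder stack, and every position owned by another rank becomes a PPMissingLayer placeholder so global layer indices still line up; LlamaForCausalLM then walks the module graph to record those placeholders so load_hf_weights can skip their weights. A rough sketch of how such helpers can look; signatures are inferred from the call sites above, and the real pp_utils presumably reads rank and stage count from the PP process group instead of taking them as arguments:

from flax import nnx

class PPMissingLayer(nnx.Module):
    """Stand-in for a layer that lives on another pipeline stage."""

    def __init__(self, *args, **kwargs):
        pass

def make_layers(num_layers, layer_fn, rank, num_stages):
    """Materialize only this stage's contiguous slice of layers; all
    other positions become PPMissingLayer so indexing stays global."""
    per_stage = -(-num_layers // num_stages)  # ceil division
    start = rank * per_stage
    end = min(start + per_stage, num_layers)
    layers = [layer_fn() if start <= i < end else PPMissingLayer()
              for i in range(num_layers)]
    return start, end, layers

At run time the model iterates islice(self.layers, self.start_layer, self.end_layer): the first rank embeds input_ids, later ranks read "hidden_states" from JaxIntermediateTensors, and only the last rank applies the final norm and LM head.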
tpu_inference/models/jax/llama4.py

@@ -1,3 +1,17 @@  (adds the same Apache 2.0 license header as above)
 import re
 from typing import List, Optional, Tuple
 
tpu_inference/models/jax/llama_eagle3.py

@@ -1,3 +1,17 @@  (adds the same Apache 2.0 license header as above)
 from typing import List, Tuple
 
 import jax
tpu_inference/models/jax/llama_guard_4.py

@@ -1,3 +1,17 @@  (adds the same Apache 2.0 license header as above)
 import re
 from typing import Any, List, Optional, Tuple
 

@@ -242,7 +256,7 @@ class LlamaGuard4ForCausalLM(nnx.Module):
             self.lm_head.input_embedding_table_DV.value)
         return logits_TV
 
-    def
+    def embed_input_ids(
         self,
         input_ids: jax.Array,
         multimodal_embeddings: Optional[List[jax.Array]] = None
tpu_inference/models/jax/qwen2.py

@@ -1,3 +1,17 @@  (adds the same Apache 2.0 license header as above)
 from typing import List, Optional, Tuple
 
 import jax

@@ -10,6 +24,7 @@ from vllm.config import VllmConfig
 from tpu_inference import utils
 from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.quantization import quantize_kv
 from tpu_inference.layers.jax.rope_interface import apply_rope
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,

@@ -152,8 +167,8 @@ class Qwen2Attention(nnx.Module):
             # q_scale = self._q_scale
             k_scale = self._k_scale
             v_scale = self._v_scale
-            k, v =
-
+            k, v = quantize_kv(self.kv_cache_quantized_dtype, k, v, k_scale,
+                               v_scale)
         new_kv_cache, outputs = attention(
             kv_cache,
             q,
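Note: llama3.py above, qwen2.py here, and qwen3.py below all replace previously inlined KV-cache quantization with the shared quantize_kv helper from layers/common/quantization.py. The call sites pin down the contract: target cache dtype plus k, v and their scales in, cache-ready k, v out. A plausible sketch, assuming a jnp dtype and symmetric scale-and-saturate quantization with a pass-through when no quantized dtype is configured (the real helper may differ):

import jax.numpy as jnp

def quantize_kv(kv_cache_dtype, k, v, k_scale, v_scale):
    """Quantize k/v for the KV cache; no-op when the cache is unquantized."""
    if kv_cache_dtype is None:
        return k, v
    info = (jnp.finfo(kv_cache_dtype)
            if jnp.issubdtype(kv_cache_dtype, jnp.floating)
            else jnp.iinfo(kv_cache_dtype))

    def _quantize(x, scale):
        x = x / scale                        # apply per-tensor scale
        x = jnp.clip(x, info.min, info.max)  # saturate to the dtype's range
        if jnp.issubdtype(kv_cache_dtype, jnp.integer):
            x = jnp.round(x)                 # round before an integer cast
        return x.astype(kv_cache_dtype)

    return _quantize(k, k_scale), _quantize(v, v_scale)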
tpu_inference/models/jax/qwen2_5_vl.py

@@ -1,3 +1,17 @@  (adds the same Apache 2.0 license header as above)
 import math
 from functools import partial
 from typing import (Callable, List, Literal, NamedTuple, Optional, TypedDict,

@@ -996,9 +1010,9 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
         split_indices = np.cumsum(sizes)[:-1]
         return tuple(jnp.split(image_embeds, split_indices))
 
-    def
-
-
+    def embed_multimodal(self, image_grid_thw: tuple[tuple[int, int, int],
+                                                     ...],
+                         **kwargs: object) -> MultiModalEmbeddings:
 
         mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
             image_grid_thw, **kwargs)

@@ -1022,7 +1036,7 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
 
         return multimodal_embeddings
 
-    def
+    def embed_input_ids(
         self, input_ids: jax.Array,
         multimodal_embeddings: Optional[jax.Array]) -> jax.Array:
 
tpu_inference/models/jax/qwen3.py

@@ -1,3 +1,17 @@  (adds the same Apache 2.0 license header as above)
 from typing import List, Optional, Tuple
 
 import jax

@@ -10,6 +24,7 @@ from vllm.config import VllmConfig
 from tpu_inference import utils
 from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.quantization import quantize_kv
 from tpu_inference.layers.jax.rope_interface import apply_rope
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.qwen2 import Qwen2DecoderLayer

@@ -125,8 +140,8 @@ class Qwen3Attention(nnx.Module):
             # q_scale = self._q_scale
             k_scale = self._k_scale
             v_scale = self._v_scale
-            k, v =
-
+            k, v = quantize_kv(self.kv_cache_quantized_dtype, k, v, k_scale,
+                               v_scale)
         new_kv_cache, outputs = attention(
             kv_cache,
             q,
tpu_inference/models/jax/utils/__init__.py (new file)

@@ -0,0 +1,13 @@  (the 13 comment lines of the Apache 2.0 license header only)
tpu_inference/models/jax/utils/file_utils.py

@@ -1,3 +1,17 @@  (adds the same Apache 2.0 license header as above)
 import glob
 import hashlib
 import os
tpu_inference/models/jax/utils/multi_modal_utils.py

@@ -1,3 +1,17 @@  (adds the same Apache 2.0 license header as above)
 from typing import Union
 
 import jax

@@ -29,25 +43,25 @@ def sanity_check_mm_encoder_outputs(
 ) -> None:
     """
     Perform sanity checks for the result of
-    [`vllm.model_executor.models.SupportsMultiModal.
+    [`vllm.model_executor.models.SupportsMultiModal.embed_multimodal`][].
     """
     assert isinstance(mm_embeddings, (list, tuple, jax.Array)), (
         "Expected multimodal embeddings to be a list/tuple of 2D tensors, "
         f"or a single 3D tensor, but got {type(mm_embeddings)} "
         "instead. This is most likely due to incorrect implementation "
-        "of the model's `
+        "of the model's `embed_multimodal` method.")
 
     assert len(mm_embeddings) == expected_num_items, (
         "Expected number of multimodal embeddings to match number of "
         f"input items: {expected_num_items}, but got {len(mm_embeddings)=} "
         "instead. This is most likely due to incorrect implementation "
-        "of the model's `
+        "of the model's `embed_multimodal` method.")
 
     assert all(e.ndim == 2 for e in mm_embeddings), (
         "Expected multimodal embeddings to be a sequence of 2D tensors, "
         f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
         "instead. This is most likely due to incorrect implementation "
-        "of the model's `
+        "of the model's `embed_multimodal` method.")
 
 
 def flatten_embeddings(embeddings: NestedTensors) -> jax.Array:
tpu_inference/models/jax/utils/qwix/__init__.py (new file)

@@ -0,0 +1,13 @@  (the 13 comment lines of the Apache 2.0 license header only)