tpu-inference 0.12.0.dev20251222__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +67 -0
- tests/core/test_dp_scheduler.py +724 -0
- tests/core/test_init.py +63 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +393 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +291 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +388 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +498 -0
- tests/kernels/quantized_matmul_kernel_test.py +159 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/layers/jax/test_qwix.py +969 -0
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +297 -0
- tests/layers/vllm/test_unquantized.py +621 -0
- tests/layers/vllm/utils.py +72 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +46 -0
- tests/lora/test_bgmv.py +57 -0
- tests/lora/test_layers.py +666 -0
- tests/lora/test_lora.py +147 -0
- tests/lora/test_lora_perf.py +67 -0
- tests/lora/utils.py +88 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +606 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +202 -0
- tests/runner/test_tpu_runner_dp.py +1033 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +215 -0
- tests/test_envs.py +280 -0
- tests/test_tpu_info.py +134 -0
- tests/test_utils.py +193 -0
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +67 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +49 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +814 -0
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +81 -0
- tpu_inference/distributed/tpu_connector.py +732 -0
- tpu_inference/distributed/utils.py +112 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +191 -0
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +399 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +272 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +1340 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +403 -0
- tpu_inference/layers/common/attention_metadata.py +48 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +23 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +600 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +268 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
- tpu_inference/layers/jax/base.py +165 -0
- tpu_inference/layers/jax/constants.py +101 -0
- tpu_inference/layers/jax/layers.py +315 -0
- tpu_inference/layers/jax/misc.py +30 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
- tpu_inference/layers/jax/moe/moe.py +249 -0
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +294 -0
- tpu_inference/layers/jax/rope_interface.py +228 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
- tpu_inference/layers/jax/sample/sampling.py +110 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
- tpu_inference/layers/jax/transformer_block.py +121 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +502 -0
- tpu_inference/layers/vllm/linear_common.py +221 -0
- tpu_inference/layers/vllm/quantization/__init__.py +55 -0
- tpu_inference/layers/vllm/quantization/awq.py +221 -0
- tpu_inference/layers/vllm/quantization/common.py +124 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
- tpu_inference/layers/vllm/sharding.py +244 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +98 -0
- tpu_inference/lora/torch_punica_tpu.py +310 -0
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +520 -0
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +978 -0
- tpu_inference/models/jax/gpt_oss.py +508 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
- tpu_inference/models/jax/llama3.py +436 -0
- tpu_inference/models/jax/llama4.py +643 -0
- tpu_inference/models/jax/llama_eagle3.py +350 -0
- tpu_inference/models/jax/llama_guard_4.py +375 -0
- tpu_inference/models/jax/qwen2.py +390 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
- tpu_inference/models/jax/qwen3.py +318 -0
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +110 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
- tpu_inference/models/jax/utils/weight_utils.py +621 -0
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
- tpu_inference/platforms/__init__.py +16 -0
- tpu_inference/platforms/tpu_platform.py +258 -0
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +890 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +166 -0
- tpu_inference/runner/kv_cache_manager.py +508 -0
- tpu_inference/runner/lora_utils.py +106 -0
- tpu_inference/runner/multimodal_manager.py +231 -0
- tpu_inference/runner/persistent_batch_manager.py +296 -0
- tpu_inference/runner/speculative_decoding_manager.py +262 -0
- tpu_inference/runner/structured_decoding_manager.py +101 -0
- tpu_inference/runner/tpu_runner.py +1768 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +430 -0
- tpu_inference/tpu_info.py +92 -0
- tpu_inference/utils.py +345 -0
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +468 -0
- tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
- tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
- tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
- tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,318 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import List, Optional, Tuple
|
|
16
|
+
|
|
17
|
+
import jax
|
|
18
|
+
import jax.numpy as jnp
|
|
19
|
+
from flax import nnx
|
|
20
|
+
from jax.sharding import Mesh
|
|
21
|
+
from transformers import Qwen3Config
|
|
22
|
+
from vllm.config import VllmConfig
|
|
23
|
+
|
|
24
|
+
from tpu_inference import utils
|
|
25
|
+
from tpu_inference.layers.common.attention_interface import attention
|
|
26
|
+
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
|
|
27
|
+
from tpu_inference.layers.common.quantization import quantize_kv
|
|
28
|
+
from tpu_inference.layers.jax.rope_interface import apply_rope
|
|
29
|
+
from tpu_inference.logger import init_logger
|
|
30
|
+
from tpu_inference.models.jax.qwen2 import Qwen2DecoderLayer
|
|
31
|
+
from tpu_inference.models.jax.qwen2 import Qwen2MLP as Qwen3MLP
|
|
32
|
+
from tpu_inference.models.jax.qwen2 import Qwen2Model
|
|
33
|
+
from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
|
|
34
|
+
load_hf_weights)
|
|
35
|
+
|
|
36
|
+
# Module-level logger.
logger = init_logger(__name__)
|
|
37
|
+
|
|
38
|
+
# Shared parameter initializer used by every sub-module in this file. The
# initial values are presumably overwritten by the checkpoint in
# `Qwen3ForCausalLM.load_weights`; confirm before relying on them.
init_fn = nnx.initializers.uniform()
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class Qwen3Attention(nnx.Module):
    """Qwen3 self-attention with per-head RMSNorm on queries and keys.

    Query/key/value/output projections are ``nnx.Einsum`` layers whose head
    axis is partitioned over the ``"model"`` mesh axis. Head counts are padded
    with ``utils.get_padded_num_heads`` for that axis and the head dimension
    with ``utils.get_padded_head_dim``. The original (unpadded) head dimension
    is what RoPE and ``attention`` are given.
    """

    def __init__(self, config: Qwen3Config, dtype: jnp.dtype, rng: nnx.Rngs,
                 mesh: Mesh, kv_cache_dtype: str):
        """Create the projections, the Q/K norms and KV-quantization settings.

        Args:
            config: Hugging Face Qwen3 config (sizes, RoPE and RMSNorm
                settings).
            dtype: Parameter dtype of every sub-module.
            rng: RNG streams for parameter initialization. Sub-modules are
                created in a fixed order (q_proj, q_norm, k_proj, k_norm,
                v_proj, o_proj). That order presumably determines how ``rng``
                is consumed, so keep it stable.
            mesh: Device mesh. The size of its ``"model"`` axis drives the
                head-count padding.
            kv_cache_dtype: KV-cache dtype string. ``"auto"`` disables KV
                quantization. Any other value is resolved to a JAX dtype via
                ``utils.get_jax_dtype_from_str_dtype``.
        """
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.rope_theta = config.rope_theta
        self.rope_scaling = getattr(config, "rope_scaling", None)
        self.rms_norm_eps = config.rms_norm_eps

        # Fall back to hidden_size / num_heads (unpadded) when the config has
        # no explicit head_dim.
        self.head_dim_original = getattr(config, "head_dim",
                                         self.hidden_size // self.num_heads)
        self.head_dim = utils.get_padded_head_dim(self.head_dim_original)

        sharding_size = mesh.shape["model"]
        self.num_heads = utils.get_padded_num_heads(self.num_heads,
                                                    sharding_size)
        self.num_kv_heads = utils.get_padded_num_heads(self.num_kv_heads,
                                                       sharding_size)

        self.mesh = mesh

        self.q_proj = nnx.Einsum(
            "TD,DNH->TNH",
            (self.hidden_size, self.num_heads, self.head_dim),
            param_dtype=dtype,
            kernel_init=nnx.with_partitioning(init_fn, (None, "model", None)),
            rngs=rng,
        )
        # NOTE(review): q_norm/k_norm normalize over the *padded* head_dim.
        # If get_padded_head_dim ever pads, the padded lanes enter the RMS
        # statistic. Confirm this is intended for such checkpoints.
        self.q_norm = nnx.RMSNorm(
            self.head_dim,
            epsilon=self.rms_norm_eps,
            param_dtype=dtype,
            scale_init=nnx.with_partitioning(init_fn, (None, )),
            rngs=rng,
        )
        self.k_proj = nnx.Einsum(
            "TD,DKH->TKH",
            (self.hidden_size, self.num_kv_heads, self.head_dim),
            param_dtype=dtype,
            kernel_init=nnx.with_partitioning(init_fn, (None, "model", None)),
            rngs=rng,
        )
        self.k_norm = nnx.RMSNorm(
            self.head_dim,
            epsilon=self.rms_norm_eps,
            param_dtype=dtype,
            scale_init=nnx.with_partitioning(init_fn, (None, )),
            rngs=rng,
        )
        self.v_proj = nnx.Einsum(
            "TD,DKH->TKH",
            (self.hidden_size, self.num_kv_heads, self.head_dim),
            param_dtype=dtype,
            kernel_init=nnx.with_partitioning(init_fn, (None, "model", None)),
            rngs=rng,
        )
        self.o_proj = nnx.Einsum(
            "TNH,NHD->TD",
            (self.num_heads, self.head_dim, self.hidden_size),
            param_dtype=dtype,
            kernel_init=nnx.with_partitioning(init_fn, ("model", None, None)),
            rngs=rng,
        )

        # Fixed unit scales used when the KV cache is quantized.
        self._q_scale = 1.0
        self._k_scale = 1.0
        self._v_scale = 1.0
        self.kv_cache_quantized_dtype = None
        if kv_cache_dtype != "auto":
            self.kv_cache_quantized_dtype = utils.get_jax_dtype_from_str_dtype(
                kv_cache_dtype)

    def __call__(
        self,
        kv_cache: Optional[jax.Array],
        x: jax.Array,
        attention_metadata: AttentionMetadata,
    ) -> Tuple[jax.Array, jax.Array]:
        """Run attention for one layer.

        Args:
            kv_cache: This layer's KV cache (the type hint admits ``None``).
            x: Hidden states, ``(T, D)`` per the projection einsums.
            attention_metadata: Metadata forwarded to ``attention``. Its
                ``input_positions`` are used for RoPE.

        Returns:
            ``(new_kv_cache, o)``, where ``o`` is ``(T, D)``.
        """
        md = attention_metadata
        # q: (T, N, H)
        q = self.q_proj(x)
        q = self.q_norm(q)
        q = apply_rope(q, md.input_positions, self.head_dim_original,
                       self.rope_theta, self.rope_scaling)

        # k: (T, K, H)
        k = self.k_proj(x)
        k = self.k_norm(k)
        k = apply_rope(k, md.input_positions, self.head_dim_original,
                       self.rope_theta, self.rope_scaling)

        # v: (T, K, H)
        v = self.v_proj(x)
        # o: (T, N, H)
        q_scale = k_scale = v_scale = None
        # Test for None explicitly. Dtype objects may define __len__ (NumPy's
        # returns the number of fields), so truthiness is not a reliable
        # "quantization enabled" check.
        if self.kv_cache_quantized_dtype is not None:
            # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
            # q_scale = self._q_scale
            k_scale = self._k_scale
            v_scale = self._v_scale
            k, v = quantize_kv(self.kv_cache_quantized_dtype, k, v, k_scale,
                               v_scale)
        new_kv_cache, outputs = attention(
            kv_cache,
            q,
            k,
            v,
            attention_metadata,
            self.mesh,
            self.head_dim_original,
            q_scale=q_scale,
            k_scale=k_scale,
            v_scale=v_scale,
        )
        # (T, D)
        o = self.o_proj(outputs)
        return new_kv_cache, o
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
class Qwen3DecoderLayer(Qwen2DecoderLayer):
    """Qwen3 transformer block.

    Only construction is defined here; the forward pass is inherited from
    ``Qwen2DecoderLayer``. The attention sub-module is ``Qwen3Attention``
    (adds per-head Q/K norms), and the MLP is ``Qwen3MLP`` (an alias of the
    Qwen2 MLP).
    """

    def __init__(self, config: Qwen3Config, dtype: jnp.dtype, rng: nnx.Rngs,
                 mesh: Mesh, kv_cache_dtype: str):
        eps = config.rms_norm_eps
        width = config.hidden_size

        def _make_rms_norm() -> nnx.RMSNorm:
            # Replicated RMSNorm over the hidden dimension.
            return nnx.RMSNorm(
                width,
                epsilon=eps,
                param_dtype=dtype,
                scale_init=nnx.with_partitioning(init_fn, (None, )),
                rngs=rng,
            )

        # Creation order is kept as-is (input norm -> attention -> post norm
        # -> MLP); presumably it fixes how ``rng`` is consumed.
        self.input_layernorm = _make_rms_norm()
        self.self_attn = Qwen3Attention(
            config=config,
            dtype=dtype,
            rng=rng,
            mesh=mesh,
            kv_cache_dtype=kv_cache_dtype,
        )
        self.post_attention_layernorm = _make_rms_norm()
        self.mlp = Qwen3MLP(config=config, dtype=dtype, rng=rng)
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
class Qwen3Model(Qwen2Model):
    """Qwen3 decoder stack built on ``Qwen2Model``.

    Holds the token embedding, ``num_hidden_layers`` ``Qwen3DecoderLayer``
    instances, the final RMSNorm and the LM-head weight. Only construction is
    defined here; the forward pass comes from ``Qwen2Model``.
    """

    def __init__(self, vllm_config: VllmConfig, rng: nnx.Rngs,
                 mesh: Mesh) -> None:
        model_cfg = vllm_config.model_config
        hf_cfg = model_cfg.hf_config
        n_vocab = model_cfg.get_vocab_size()
        param_dtype = model_cfg.dtype
        eps = hf_cfg.rms_norm_eps
        width = hf_cfg.hidden_size
        # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
        cache_dtype_str = vllm_config.cache_config.cache_dtype

        # Creation order (embed -> layers -> norm -> lm_head) is kept as-is;
        # presumably it fixes how ``rng`` is consumed.
        self.embed = nnx.Embed(
            num_embeddings=n_vocab,
            features=width,
            param_dtype=param_dtype,
            embedding_init=nnx.with_partitioning(init_fn, ("model", None)),
            rngs=rng,
        )
        self.layers = [
            Qwen3DecoderLayer(config=hf_cfg,
                              dtype=param_dtype,
                              rng=rng,
                              mesh=mesh,
                              kv_cache_dtype=cache_dtype_str)
            for _ in range(hf_cfg.num_hidden_layers)
        ]
        self.norm = nnx.RMSNorm(
            width,
            epsilon=eps,
            param_dtype=param_dtype,
            scale_init=nnx.with_partitioning(init_fn, (None, )),
            rngs=rng,
        )

        if not hf_cfg.tie_word_embeddings:
            # Separate (hidden_size, vocab_size) output projection, sharded on
            # the vocab axis.
            self.lm_head = nnx.Param(
                init_fn(rng.params(), (width, n_vocab), param_dtype),
                sharding=(None, "model"),
            )
        else:
            # Tied: reuse the embedding table as the output projection.
            self.lm_head = self.embed.embedding
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
class Qwen3ForCausalLM(nnx.Module):
    """Causal-LM wrapper around ``Qwen3Model``.

    Exposes the forward pass, the logit projection and Hugging Face weight
    loading.
    """

    def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array,
                 mesh: Mesh) -> None:
        self.vllm_config = vllm_config
        self.rng = nnx.Rngs(rng_key)
        self.mesh = mesh
        self.model = Qwen3Model(vllm_config=vllm_config,
                                rng=self.rng,
                                mesh=mesh)

    def __call__(
        self,
        kv_caches: List[jax.Array],
        input_ids: jax.Array,
        attention_metadata: AttentionMetadata,
        *args,
    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]:
        """Run the decoder stack.

        Extra positional ``args`` are accepted and ignored.

        Returns:
            ``(kv_caches, hidden_states, [])``: the caches returned by the
            model, its output hidden states, and an always-empty list.
        """
        new_caches, hidden = self.model(kv_caches, input_ids,
                                        attention_metadata)
        return new_caches, hidden, []

    def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
        """Project hidden states onto the vocabulary.

        With tied embeddings ``lm_head`` is the embedding table and is
        transposed before the product; otherwise it is used as-is.
        """
        head = self.model.lm_head.value
        tied = self.vllm_config.model_config.hf_config.tie_word_embeddings
        weight = head.T if tied else head
        return jnp.dot(hidden_states, weight)

    def load_weights(self, rng_key: jax.Array):
        """Load Hugging Face checkpoint weights into this model.

        Args:
            rng_key: Key used to rebuild ``self.rng``.
        """
        # NOTE: Since we are using nnx.eval_shape to init the model,
        # we have to pass dynamic arrays here for __call__'s usage.
        self.rng = nnx.Rngs(rng_key)

        model_config = self.vllm_config.model_config

        # Key: path to a HF layer weight
        # Value: path to a nnx layer weight
        mappings = {
            "model.embed_tokens": "model.embed.embedding",
            "model.layers.*.input_layernorm":
            "model.layers.*.input_layernorm.scale",
            "model.layers.*.mlp.down_proj":
            "model.layers.*.mlp.down_proj.kernel",
            "model.layers.*.mlp.gate_proj":
            "model.layers.*.mlp.gate_proj.kernel",
            "model.layers.*.mlp.up_proj": "model.layers.*.mlp.up_proj.kernel",
            "model.layers.*.post_attention_layernorm":
            "model.layers.*.post_attention_layernorm.scale",
            "model.layers.*.self_attn.k_norm":
            "model.layers.*.self_attn.k_norm.scale",
            "model.layers.*.self_attn.k_proj":
            "model.layers.*.self_attn.k_proj.kernel",
            "model.layers.*.self_attn.o_proj":
            "model.layers.*.self_attn.o_proj.kernel",
            "model.layers.*.self_attn.q_norm":
            "model.layers.*.self_attn.q_norm.scale",
            "model.layers.*.self_attn.q_proj":
            "model.layers.*.self_attn.q_proj.kernel",
            "model.layers.*.self_attn.v_proj":
            "model.layers.*.self_attn.v_proj.kernel",
            "model.norm": "model.norm.scale",
        }

        # A separate lm_head checkpoint tensor exists only when it is not
        # tied to the embeddings.
        if not model_config.hf_config.tie_word_embeddings:
            mappings["lm_head"] = "model.lm_head"

        metadata_map = get_default_maps(model_config, self.mesh, mappings)
        load_hf_weights(
            vllm_config=self.vllm_config,
            model=self,
            metadata_map=metadata_map,
            mesh=self.mesh,
        )
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import glob
|
|
16
|
+
import hashlib
|
|
17
|
+
import os
|
|
18
|
+
import shutil
|
|
19
|
+
import subprocess
|
|
20
|
+
from typing import List, Optional
|
|
21
|
+
|
|
22
|
+
import filelock
|
|
23
|
+
import huggingface_hub.constants
|
|
24
|
+
from huggingface_hub import HfFileSystem, snapshot_download
|
|
25
|
+
from tqdm.auto import tqdm
|
|
26
|
+
|
|
27
|
+
from tpu_inference.logger import init_logger
|
|
28
|
+
|
|
29
|
+
# Module-level logger.
logger = init_logger(__name__)
|
|
30
|
+
# Do not set the Hugging Face token here; provide it via the `HF_TOKEN` environment variable.
|
|
31
|
+
hfs = HfFileSystem()  # Shared Hub filesystem client used by the HF helpers.
|
|
32
|
+
|
|
33
|
+
LOCK_DIR = "/tmp/lock"  # Directory holding per-model lock files (see get_lock).
|
|
34
|
+
|
|
35
|
+
##### Local file utils #####
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def run_cmd(cmd: str, *args, **kwargs) -> subprocess.CompletedProcess:
    """Run ``cmd`` as a subprocess without a shell.

    The command string is split on whitespace only (shell quoting is not
    honoured). Extra positional and keyword arguments are forwarded to
    ``subprocess.run``.
    """
    argv = cmd.split()
    return subprocess.run(argv, *args, **kwargs)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def delete_file(path: str) -> None:
    """Remove the regular file at ``path``; log an error if there is none."""
    if not os.path.isfile(path):
        logger.error(f"Trying to delete non-existing file: {path}")
        return
    os.remove(path)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def list_files(dir: str, pattern: str = "*") -> List[str]:
    """Return paths inside ``dir`` that match the glob ``pattern``.

    Matching is done by ``glob.glob`` without ``recursive=True``. The result
    order is whatever ``glob`` yields.
    """
    full_pattern = os.path.join(dir, pattern)
    return glob.glob(full_pattern)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def get_lock(model_name_or_path: str):
    """Return a file lock dedicated to one model.

    The lock file lives in ``LOCK_DIR``. Its name is a SHA-256 digest of the
    sanitized model name followed by the name itself, so callers locking the
    same model share one lock file.

    Args:
        model_name_or_path: HF repo id or local path (anything ``str()``-able).

    Returns:
        A ``filelock.FileLock`` for that model (not yet acquired).
    """
    lock_dir = LOCK_DIR
    model_name_or_path = str(model_name_or_path)
    # Create the lock directory itself. The previous code created only its
    # parent (os.path.dirname(LOCK_DIR), i.e. "/tmp"), so acquiring the lock
    # could fail when LOCK_DIR did not exist yet, unless the installed
    # filelock version happens to create missing parents.
    os.makedirs(lock_dir, exist_ok=True)
    model_name = model_name_or_path.replace("/", "-")
    hash_name = hashlib.sha256(model_name.encode()).hexdigest()
    # add hash to avoid conflict with old users' lock files
    lock_file_name = hash_name + model_name + ".lock"
    # mode 0o666 is required for the filelock to be shared across users
    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name),
                             mode=0o666)
    return lock
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def get_free_disk_size(path: str = "/") -> int:
    """Return the number of free bytes on the filesystem containing ``path``."""
    usage = shutil.disk_usage(path)
    return usage.free
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
##### HuggingFace file utils #####
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def is_hf_repo(repo_id: str) -> bool:
    """Return whether ``repo_id`` exists according to the shared ``hfs`` client."""
    found = hfs.exists(repo_id)
    return found
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def list_hf_repo(repo_id: str, pattern: str = "**") -> List[str]:
    """Return Hub paths under ``repo_id`` that match the glob ``pattern``.

    The glob is evaluated by the shared ``hfs`` filesystem client.
    """
    glob_expr = os.path.join(repo_id, pattern)
    return hfs.glob(glob_expr)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def get_hf_model_weights_size(repo_id: str, weights_format: str) -> int:
    """Return the total size in bytes of the Hub files matching ``weights_format``.

    Each matching file's ``size`` is read via ``hfs.info``. The result is 0
    when nothing matches.
    """
    return sum(
        int(hfs.info(path)["size"])
        for path in list_hf_repo(repo_id, weights_format))
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class DisabledTqdm(tqdm):
    """A ``tqdm`` subclass whose progress bar is always disabled.

    Used as ``tqdm_class`` for ``snapshot_download`` so weight downloads do
    not print progress bars.
    """

    def __init__(self, *args, **kwargs):
        # Force disable=True by overriding the kwarg. The previous form
        # ``super().__init__(*args, **kwargs, disable=True)`` raised
        # TypeError ("got multiple values for keyword argument 'disable'")
        # whenever the caller already passed ``disable``.
        kwargs["disable"] = True
        super().__init__(*args, **kwargs)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def download_model_weights_from_hf(model_path: str, cache_dir: Optional[str],
                                   weights_format: str) -> List[str]:
    """Download a model's weight files from the Hugging Face Hub.

    The download runs while holding the per-model lock from ``get_lock``.
    Only files matching ``weights_format`` are fetched, and progress bars are
    suppressed. ``local_files_only`` follows
    ``huggingface_hub.constants.HF_HUB_OFFLINE``.

    Args:
        model_path: Hub repo id.
        cache_dir: Cache directory for ``snapshot_download``. ``None`` leaves
            the choice to the ``HF_HOME`` / ``HF_HUB_CACHE`` defaults.
        weights_format: Glob such as ``"*.safetensors"``. It is used both as
            the download filter and to list files in the local snapshot.

    Returns:
        Local paths of the matching weight files. (The return annotation was
        previously ``str``; ``list_files`` returns a list.)
    """
    with get_lock(model_path):
        local_dir = snapshot_download(
            model_path,
            cache_dir=cache_dir,  # can be specified by HF_HOME or HF_HUB_CACHE
            allow_patterns=weights_format,
            tqdm_class=DisabledTqdm,
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
        )
    # NOTE(review): list_files is non-recursive, so weights stored in
    # subdirectories of the snapshot would not be returned. Confirm layout.
    local_files = list_files(local_dir, weights_format)
    return local_files
|
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Union
|
|
16
|
+
|
|
17
|
+
import jax
|
|
18
|
+
import jax.numpy as jnp
|
|
19
|
+
from typing_extensions import TypeAlias
|
|
20
|
+
from vllm.logger import init_logger
|
|
21
|
+
|
|
22
|
+
logger = init_logger(__name__)
|
|
23
|
+
|
|
24
|
+
NestedTensors: TypeAlias = Union[list["NestedTensors"], list["jax.Array"],
|
|
25
|
+
"jax.Array", tuple["jax.Array", ...]]
|
|
26
|
+
"""
|
|
27
|
+
Uses a list instead of a tensor if the dimensions of each element do not match.
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
MultiModalEmbeddings = Union[list[jax.Array], jax.Array, tuple[jax.Array, ...]]
|
|
31
|
+
"""
|
|
32
|
+
The output embeddings must be one of the following formats:
|
|
33
|
+
|
|
34
|
+
- A list or tuple of 2D tensors, where each tensor corresponds to
|
|
35
|
+
each input multimodal data item (e.g, image).
|
|
36
|
+
- A single 3D tensor, with the batch dimension grouping the 2D tensors.
|
|
37
|
+
"""
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def sanity_check_mm_encoder_outputs(
|
|
41
|
+
mm_embeddings: MultiModalEmbeddings,
|
|
42
|
+
expected_num_items: int,
|
|
43
|
+
) -> None:
|
|
44
|
+
"""
|
|
45
|
+
Perform sanity checks for the result of
|
|
46
|
+
[`vllm.model_executor.models.SupportsMultiModal.get_multimodal_embeddings`][].
|
|
47
|
+
"""
|
|
48
|
+
assert isinstance(mm_embeddings, (list, tuple, jax.Array)), (
|
|
49
|
+
"Expected multimodal embeddings to be a list/tuple of 2D tensors, "
|
|
50
|
+
f"or a single 3D tensor, but got {type(mm_embeddings)} "
|
|
51
|
+
"instead. This is most likely due to incorrect implementation "
|
|
52
|
+
"of the model's `get_multimodal_embeddings` method.")
|
|
53
|
+
|
|
54
|
+
assert len(mm_embeddings) == expected_num_items, (
|
|
55
|
+
"Expected number of multimodal embeddings to match number of "
|
|
56
|
+
f"input items: {expected_num_items}, but got {len(mm_embeddings)=} "
|
|
57
|
+
"instead. This is most likely due to incorrect implementation "
|
|
58
|
+
"of the model's `get_multimodal_embeddings` method.")
|
|
59
|
+
|
|
60
|
+
assert all(e.ndim == 2 for e in mm_embeddings), (
|
|
61
|
+
"Expected multimodal embeddings to be a sequence of 2D tensors, "
|
|
62
|
+
f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
|
|
63
|
+
"instead. This is most likely due to incorrect implementation "
|
|
64
|
+
"of the model's `get_multimodal_embeddings` method.")
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def flatten_embeddings(embeddings: NestedTensors) -> jax.Array:
|
|
68
|
+
"""
|
|
69
|
+
Recursively flattens and concatenates NestedTensors on all but the last
|
|
70
|
+
dimension.
|
|
71
|
+
"""
|
|
72
|
+
|
|
73
|
+
if isinstance(embeddings, jax.Array):
|
|
74
|
+
return embeddings.reshape(-1, embeddings.shape[-1])
|
|
75
|
+
|
|
76
|
+
return jnp.concatenate([flatten_embeddings(t) for t in embeddings], axis=0)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _embedding_count_expression(embeddings: NestedTensors) -> str:
|
|
80
|
+
"""
|
|
81
|
+
Constructs a debugging representation of the number of embeddings in the
|
|
82
|
+
NestedTensors.
|
|
83
|
+
"""
|
|
84
|
+
|
|
85
|
+
if isinstance(embeddings, jax.Array):
|
|
86
|
+
return " x ".join([str(dim) for dim in embeddings.shape[:-1]])
|
|
87
|
+
|
|
88
|
+
return " + ".join(
|
|
89
|
+
_embedding_count_expression(inner) for inner in embeddings)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _merge_multimodal_embeddings(
|
|
93
|
+
inputs_embeds: jax.Array,
|
|
94
|
+
is_multimodal: jax.Array,
|
|
95
|
+
multimodal_embeddings: jax.Array,
|
|
96
|
+
) -> jax.Array:
|
|
97
|
+
"""
|
|
98
|
+
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
|
|
99
|
+
positions in ``inputs_embeds`` corresponding to placeholder tokens in
|
|
100
|
+
``input_ids``.
|
|
101
|
+
This returns a new array with the updated values.
|
|
102
|
+
Note:
|
|
103
|
+
This returns a new array with the updated values.
|
|
104
|
+
"""
|
|
105
|
+
# The check for matching number of tokens is removed as it is not
|
|
106
|
+
# JIT-compatible. If the shapes mismatch, JAX will raise an error
|
|
107
|
+
# during execution anyway. The user-friendly error message is
|
|
108
|
+
# sacrificed for JIT compatibility.
|
|
109
|
+
|
|
110
|
+
# JIT-compatible implementation using jnp.where to avoid
|
|
111
|
+
# NonConcreteBooleanIndexError.
|
|
112
|
+
# Create a dummy row to handle indices for non-multimodal tokens.
|
|
113
|
+
# The content of the dummy row does not matter as it will be masked out.
|
|
114
|
+
dummy_row = jnp.zeros_like(multimodal_embeddings[0:1])
|
|
115
|
+
|
|
116
|
+
# Prepend the dummy row to the flattened embeddings.
|
|
117
|
+
flattened_padded = jnp.concatenate([dummy_row, multimodal_embeddings],
|
|
118
|
+
axis=0)
|
|
119
|
+
|
|
120
|
+
# Create gather indices. For each token in the input sequence, this gives
|
|
121
|
+
# the index into `flattened_padded`.
|
|
122
|
+
# For non-multimodal tokens, the index will be 0 (pointing to the dummy
|
|
123
|
+
# row). For the k-th multimodal token, the index will be k.
|
|
124
|
+
gather_indices = jnp.cumsum(is_multimodal)
|
|
125
|
+
|
|
126
|
+
# Gather the embeddings to be placed.
|
|
127
|
+
update_values = flattened_padded[gather_indices]
|
|
128
|
+
|
|
129
|
+
# Use jnp.where to select between original and new embeddings.
|
|
130
|
+
condition = jnp.expand_dims(is_multimodal, axis=-1)
|
|
131
|
+
return jnp.where(condition, update_values, inputs_embeds)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def merge_multimodal_embeddings(
|
|
135
|
+
input_ids: jax.Array,
|
|
136
|
+
inputs_embeds: jax.Array,
|
|
137
|
+
multimodal_embeddings: jax.Array,
|
|
138
|
+
placeholder_token_id: Union[int, list[int]],
|
|
139
|
+
) -> jax.Array:
|
|
140
|
+
"""
|
|
141
|
+
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
|
|
142
|
+
positions in ``inputs_embeds`` corresponding to placeholder tokens in
|
|
143
|
+
``input_ids``.
|
|
144
|
+
|
|
145
|
+
``placeholder_token_id`` can be a list of token ids (e.g, token ids
|
|
146
|
+
of img_start, img_break, and img_end tokens) when needed: This means
|
|
147
|
+
the order of these tokens in the ``input_ids`` MUST MATCH the order of
|
|
148
|
+
their embeddings in ``multimodal_embeddings`` since we need to
|
|
149
|
+
slice-merge instead of individually scattering.
|
|
150
|
+
|
|
151
|
+
For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
|
|
152
|
+
- T is text token
|
|
153
|
+
- S is image start token
|
|
154
|
+
- I is image embedding token
|
|
155
|
+
- B is image break token
|
|
156
|
+
- E is image end token.
|
|
157
|
+
|
|
158
|
+
Then the image embeddings (that correspond to I's) from vision encoder
|
|
159
|
+
must be padded with embeddings of S, B, and E in the same order of
|
|
160
|
+
input_ids for a correct embedding merge.
|
|
161
|
+
|
|
162
|
+
This returns a new array with the updated values.
|
|
163
|
+
"""
|
|
164
|
+
if isinstance(placeholder_token_id, list):
|
|
165
|
+
placeholder_token_id = jnp.array(placeholder_token_id)
|
|
166
|
+
|
|
167
|
+
return _merge_multimodal_embeddings(
|
|
168
|
+
inputs_embeds,
|
|
169
|
+
jnp.isin(input_ids, placeholder_token_id),
|
|
170
|
+
multimodal_embeddings,
|
|
171
|
+
)
|
|
172
|
+
|
|
173
|
+
return _merge_multimodal_embeddings(
|
|
174
|
+
inputs_embeds,
|
|
175
|
+
(input_ids == placeholder_token_id),
|
|
176
|
+
multimodal_embeddings,
|
|
177
|
+
)
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|