tpu-inference 0.11.1.dev202511220812-py3-none-any.whl → 0.13.2.dev20251230-py3-none-any.whl
This diff shows the changes between two publicly released versions of this package as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +317 -34
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +26 -6
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +25 -4
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +25 -12
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +32 -9
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +101 -494
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
- tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +112 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +18 -5
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +179 -51
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
- tpu_inference/models/jax/utils/weight_utils.py +234 -155
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +51 -72
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +180 -80
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +55 -33
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +16 -3
- tpu_inference/runner/tpu_runner.py +124 -61
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +84 -22
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +66 -52
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -186
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tpu_inference/models/common/model_loader.py

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import functools
 from typing import Any, Optional
 
@@ -5,22 +19,31 @@ import jax
 import torch
 from flax import nnx
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
-from torchax.ops.mappings import j2t_dtype
 from transformers import PretrainedConfig
 from vllm.config import VllmConfig
+from vllm.model_executor.model_loader import get_model_loader
+from vllm.model_executor.model_loader.runai_streamer_loader import \
+    RunaiModelStreamerLoader
 from vllm.utils.func_utils import supports_kw
 
 from tpu_inference import envs
 from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
-from tpu_inference.models.jax.utils.quantization.quantization_utils import (
+from tpu_inference.models.jax.utils.qwix.qwix_utils import (
     apply_qwix_on_abstract_model, apply_qwix_quantization,
-    load_random_weights_into_qwix_abstract_model)
+    load_random_weights_into_qwix_abstract_model,
+    update_vllm_config_for_qwix_quantization)
+from tpu_inference.utils import to_jax_dtype, to_torch_dtype
 
 logger = init_logger(__name__)
 
 _MODEL_REGISTRY = {}
 
+# List of architectures that are preferred to use "vllm" implementation over
+# "flax_nnx" implementation due to various factors such as performance.
+_VLLM_PREFERRED_ARCHITECTURES: frozenset[str] = frozenset(
+    {"GptOssForCausalLM"})
+
 
 class UnsupportedArchitectureError(ValueError):
     """Raised when a model architecture is not supported in the registry."""
@@ -177,7 +200,23 @@ def _get_nnx_model(
     # the model creation again, otherwise the model forward will have
     # non-trivial overhead in PjitFunction.
     with mesh:
-        model.load_weights(rng)
+        loader = get_model_loader(vllm_config.load_config)
+        if isinstance(loader, RunaiModelStreamerLoader):
+            model_weights = vllm_config.model_config.model
+            if hasattr(vllm_config.model_config, "model_weights"):
+                model_weights = vllm_config.model_config.model_weights
+            weights_iterator = loader._get_weights_iterator(
+                model_weights, vllm_config.model_config.revision)
+            # We set the weights iterator at runtime, to prevent having to change
+            # every model's load_weights signature. This also prevents us from hitting
+            # a TypeError at runtime if you use the RunaiModelStreamerLoader with any
+            # flax_nnx model whose load_weights function does not accept the
+            # weights_iterator keyword argument.
+            vllm_config.model_config.model_weights_iterator = weights_iterator
+            model.load_weights(rng)
+            del vllm_config.model_config.model_weights_iterator
+        else:
+            model.load_weights(rng)
     jit_model = create_jit_model(
         model,
         use_qwix_on_abstract_model=should_apply_qwix_on_abstract_model)
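
Note on the hunk above: the streamer's weights iterator is attached to vllm_config.model_config at runtime and removed right after the load, instead of widening every model's load_weights signature. A minimal sketch of that attach-use-delete pattern, using a hypothetical SimpleNamespace config and a load_weights stand-in rather than the real tpu-inference objects:

    from types import SimpleNamespace

    def load_weights(config):
        # Models that support streaming read the iterator if it is present.
        it = getattr(config, "model_weights_iterator", None)
        print("streaming weights" if it is not None else "reading checkpoint files")

    config = SimpleNamespace(model="org/some-model")
    streamed = iter([("layer.0.weight", b"\x00")])  # pretend streamer output

    config.model_weights_iterator = streamed  # attach at runtime
    try:
        load_weights(config)                  # prints: streaming weights
    finally:
        del config.model_weights_iterator     # never leaks into later loads
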
@@ -191,6 +230,13 @@ def get_flax_model(
     mesh: Mesh,
     is_draft_model: bool = False,
 ) -> nnx.Module:
+    model_dtype = to_jax_dtype(vllm_config.model_config.dtype)
+    vllm_config.model_config.dtype = model_dtype
+
+    # Only perform qwix quantization if it is jax model.
+    if vllm_config.model_config:
+        update_vllm_config_for_qwix_quantization(vllm_config)
+
     if is_draft_model:
         model_class = _get_model_architecture(
             vllm_config.speculative_config.draft_model_config.hf_config)
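
Note on the hunk above: model_config.dtype is normalized to a JAX dtype before the flax model is built (get_vllm_model below does the torch-side equivalent with to_torch_dtype). The helpers themselves are not shown in this diff, so the following is only a guessed sketch of what such a torch-to-JAX mapping might look like, not the real tpu_inference.utils implementation:

    import numpy as np
    import jax.numpy as jnp
    import torch

    # Hypothetical mapping; the real to_jax_dtype is not shown in this diff.
    _TORCH_TO_JAX = {
        torch.float32: jnp.float32,
        torch.bfloat16: jnp.bfloat16,
        torch.float16: jnp.float16,
    }

    def to_jax_dtype_sketch(dtype):
        if isinstance(dtype, torch.dtype):
            return _TORCH_TO_JAX[dtype]
        # Already a numpy-style dtype or a name such as "float32".
        return np.dtype(dtype)

    assert to_jax_dtype_sketch(torch.bfloat16) == jnp.bfloat16
    assert to_jax_dtype_sketch("float32") == jnp.float32
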
@@ -199,7 +245,9 @@ def get_flax_model(
         vllm_config.model_config.hf_config)
     jit_model = _get_nnx_model(model_class, vllm_config, rng, mesh)
     kv_cache_sharding = NamedSharding(
-        mesh,
+        mesh,
+        PartitionSpec(ShardingAxisName.ATTN_DATA, None,
+                      ShardingAxisName.ATTN_HEAD))
     hidden_states_sharding = NamedSharding(mesh,
                                            PartitionSpec(
                                                ShardingAxisName.ATTN_DATA,
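
Note on the hunk above: the KV cache is now sharded along a data axis and an attention-head axis, with the middle dimension replicated (the None entry in the PartitionSpec). A self-contained JAX example of the same NamedSharding/PartitionSpec mechanics; the axis names "data" and "model" stand in for the ShardingAxisName constants:

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    # A 1 x N device mesh; runs on any device count, including a single CPU.
    devices = np.array(jax.devices()).reshape(1, -1)
    mesh = Mesh(devices, axis_names=("data", "model"))

    # (batch, seq, heads): batch over "data", heads over "model",
    # sequence replicated.
    sharding = NamedSharding(mesh, PartitionSpec("data", None, "model"))

    x = jax.device_put(jnp.zeros((8, 128, 16)), sharding)
    print(x.sharding.spec)  # PartitionSpec('data', None, 'model')
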
@@ -217,14 +265,17 @@ def get_flax_model(
             hidden_states_sharding, # aux hidden states
         ),
         donate_argnums=2, # 0 is graphdef, 1 is state, 2 is kv_cache
-        static_argnums=
+        static_argnums=(
+            7, 10, 11
+        ), #7 is layer_name_to_kvcache_index, 10 is is_first_rank, 11 is is_last_rank
     )
     def run_model(graphdef, state, *args):
         model = nnx.merge(graphdef, state)
         return model(*args)
 
     logits_sharding = NamedSharding(
-        mesh,
+        mesh,
+        PartitionSpec(ShardingAxisName.MLP_DATA, ShardingAxisName.MLP_TENSOR))
 
     @functools.partial(
         jax.jit,
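
Note on the hunk above: run_model uses the standard flax NNX functional pattern — the module is split into a static graphdef plus an array state, so the merged call can be traced by jax.jit with the graphdef closed over via functools.partial. A runnable miniature of the pattern (the Tiny module is illustrative only):

    import functools

    import jax
    import jax.numpy as jnp
    from flax import nnx

    class Tiny(nnx.Module):
        def __init__(self, rngs: nnx.Rngs):
            self.linear = nnx.Linear(4, 2, rngs=rngs)

        def __call__(self, x):
            return self.linear(x)

    graphdef, state = nnx.split(Tiny(nnx.Rngs(0)))  # static structure vs. arrays

    @jax.jit
    def run_model(graphdef, state, x):
        model = nnx.merge(graphdef, state)
        return model(x)

    model_fn = functools.partial(run_model, graphdef)  # close over the static part
    print(model_fn(state, jnp.ones((1, 4))).shape)     # (1, 2)
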
@@ -237,10 +288,9 @@ def get_flax_model(
 
     # Multi-modal support only
     # This function calculates the image token's embeddings by VIT
-    def
-                **kwargs):
+    def run_embed_multimodal(graphdef, state, image_grid_thw, **kwargs):
         model = nnx.merge(graphdef, state)
-        return model.
+        return model.embed_multimodal(image_grid_thw, **kwargs)
 
     embed_sharding = NamedSharding(mesh, PartitionSpec(None))
     # This function will calculates the embeddings of input texts and then merge with the image embeddings
@@ -248,9 +298,9 @@ def get_flax_model(
         jax.jit,
         out_shardings=(embed_sharding),
     )
-    def
+    def run_embed_input_ids(graphdef, state, *args, **kwargs):
         model = nnx.merge(graphdef, state)
-        return model.
+        return model.embed_input_ids(*args, **kwargs)
 
     # For models that want to work with EAGLE-3 speculative decoding
     @functools.partial(
@@ -266,10 +316,8 @@ def get_flax_model(
                                None)
     model_fn = functools.partial(run_model, graphdef)
     compute_logits_fn = functools.partial(run_compute_logits, graphdef)
-
-
-    get_input_embeddings_fn = functools.partial(run_get_input_embeddings,
-                                                graphdef)
+    embed_multimodal_fn = functools.partial(run_embed_multimodal, graphdef)
+    embed_input_ids_fn = functools.partial(run_embed_input_ids, graphdef)
     lora_manager, model = None, None
     combine_hidden_states_fn = functools.partial(combine_hidden_states,
                                                  graphdef)
@@ -280,8 +328,8 @@ def get_flax_model(
 
     multimodal_fns = {
         "precompile_vision_encoder_fn": precompile_vision_encoder_fn,
-        "
-        "
+        "embed_multimodal_fn": embed_multimodal_fn,
+        "embed_input_ids_fn": embed_input_ids_fn,
         "get_mrope_input_positions_fn": get_mrope_input_positions_fn,
     }
 
@@ -293,6 +341,8 @@ def get_vllm_model(
     rng: jax.Array,
     mesh: Mesh,
 ):
+    model_dtype = to_torch_dtype(vllm_config.model_config.dtype)
+    vllm_config.model_config.dtype = model_dtype
     from tpu_inference.models.vllm.vllm_model_wrapper import VllmModelWrapper
 
     model = VllmModelWrapper(
@@ -318,24 +368,39 @@ def get_model(
     impl = envs.MODEL_IMPL_TYPE
     logger.info(f"Loading model with MODEL_IMPL_TYPE={impl}")
 
-    if impl == "
-
-
-
-
-
-
-
-
-
-
-
-
+    if impl == "auto":
+        # Resolve "auto" based on architecture
+        architectures = getattr(vllm_config.model_config.hf_config,
+                                "architectures", [])
+        assert len(architectures) == 1, (
+            f"Expected exactly one architecture, got {len(architectures)}: "
+            f"{architectures}")
+        arch = architectures[0]
+        impl = "vllm" if arch in _VLLM_PREFERRED_ARCHITECTURES else "flax_nnx"
+        logger.info(f"Resolved MODEL_IMPL_TYPE 'auto' to '{impl}'")
+
+    match impl:
+        case "flax_nnx":
+            if vllm_config.parallel_config.pipeline_parallel_size > 1:
+                logger.warning(
+                    "PP is not fully supported on Jax flax_nnx models yet, fallback to vllm models."
+                )
+                return get_vllm_model(vllm_config, rng, mesh)
+            try:
+                # Try to load the flax model first
+                return get_flax_model(vllm_config, rng, mesh, is_draft_model)
+            except UnsupportedArchitectureError as e:
+                # Convert the error message to a string to check its contents
+                error_msg = str(e)
+
+                logger.warning(error_msg)
+
+                # Fall back to the vLLM model and updating the dtype accordingly
+                return get_vllm_model(vllm_config, rng, mesh)
+        case "vllm":
             return get_vllm_model(vllm_config, rng, mesh)
-
-
-    else:
-        raise NotImplementedError("Unsupported MODEL_IMPL_TYPE")
+        case _:
+            raise NotImplementedError(f"Unsupported MODEL_IMPL_TYPE: {impl}")
 
 
 def _validate_model_interface(model: Any) -> None:
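
Note on the hunk above: MODEL_IMPL_TYPE=auto is resolved per architecture before dispatch. The resolution rule restated as a small standalone function, mirroring _VLLM_PREFERRED_ARCHITECTURES from this diff:

    VLLM_PREFERRED = frozenset({"GptOssForCausalLM"})

    def resolve_impl(impl: str, architectures: list[str]) -> str:
        if impl != "auto":
            return impl
        assert len(architectures) == 1, architectures
        return "vllm" if architectures[0] in VLLM_PREFERRED else "flax_nnx"

    print(resolve_impl("auto", ["GptOssForCausalLM"]))  # vllm
    print(resolve_impl("auto", ["LlamaForCausalLM"]))   # flax_nnx
    print(resolve_impl("vllm", ["LlamaForCausalLM"]))   # vllm
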
@@ -421,6 +486,17 @@ def register_model(arch: str, model: Any) -> None:
             "This is a JAX model and does not implement the PyTorch forward method."
         )
 
+    # Same as `forward`, this is a dummy method to satisfy vLLM's type checks.
+    def unimplemented_embed_input_ids(
+        self,
+        input_ids: "torch.Tensor",
+        positions: "torch.Tensor",
+        inputs_embeds: Optional["torch.Tensor"] = None,
+    ) -> "torch.Tensor":
+        raise NotImplementedError(
+            "This is a JAX model and does not implement the PyTorch embed_input_ids method."
+        )
+
     # We need a custom __init__ that only calls torch.nn.Module's init,
     # to avoid triggering JAX logic when vLLM inspects the class.
     def wrapper_init(self, *args, **kwargs):
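
Note on the hunks around register_model: the PyTorch-facing dummy class is assembled from a dict of stub methods (the next hunk adds embed_input_ids to that dict). A minimal sketch of building such a shim with type(); class and method names here are illustrative, not the exact tpu-inference ones:

    import torch

    def wrapper_init(self, *args, **kwargs):
        # Only run torch.nn.Module's init; skip any heavier subclass logic.
        torch.nn.Module.__init__(self)

    def unimplemented_forward(self, *args, **kwargs):
        raise NotImplementedError("JAX model: no PyTorch forward.")

    Shim = type(
        "MyJaxModelShim",
        (torch.nn.Module,),
        {
            "__init__": wrapper_init,
            "forward": unimplemented_forward,
            # Prevent weight loading into the dummy class.
            "load_weights": lambda self, *args, **kwargs: None,
        },
    )

    print(isinstance(Shim(), torch.nn.Module))  # True
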
@@ -434,6 +510,7 @@ def register_model(arch: str, model: Any) -> None:
         {
             "__init__": wrapper_init,
             "forward": unimplemented_forward,
+            "embed_input_ids": unimplemented_embed_input_ids,
             # Prevent vLLM from trying to load weights into this dummy class.
             "load_weights": lambda self, *args, **kwargs: None,
         })

(new file — filename not preserved in this view)

@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.