tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +22 -1
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +167 -97
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +31 -9
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +280 -210
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +77 -36
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +91 -31
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -4
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -71
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +158 -63
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +53 -30
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +105 -57
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +65 -19
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +65 -52
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
- tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511270815.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
tpu_inference/models/common/model_loader.py +77 -36
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import functools
 from typing import Any, Optional
 
@@ -5,7 +19,6 @@ import jax
 import torch
 from flax import nnx
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
-from torchax.ops.mappings import j2t_dtype
 from transformers import PretrainedConfig
 from vllm.config import VllmConfig
 from vllm.model_executor.model_loader import get_model_loader
@@ -16,14 +29,20 @@ from vllm.utils.func_utils import supports_kw
 from tpu_inference import envs
 from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
-from tpu_inference.models.jax.utils.quantization.quantization_utils import (
+from tpu_inference.models.jax.utils.qwix.qwix_utils import (
     apply_qwix_on_abstract_model, apply_qwix_quantization,
     load_random_weights_into_qwix_abstract_model)
+from tpu_inference.utils import to_jax_dtype, to_torch_dtype
 
 logger = init_logger(__name__)
 
 _MODEL_REGISTRY = {}
 
+# List of architectures that are preferred to use "vllm" implementation over
+# "flax_nnx" implementation due to various factors such as performance.
+_VLLM_PREFERRED_ARCHITECTURES: frozenset[str] = frozenset(
+    {"GptOssForCausalLM"})
+
 
 class UnsupportedArchitectureError(ValueError):
     """Raised when a model architecture is not supported in the registry."""
@@ -210,6 +229,9 @@ def get_flax_model(
     mesh: Mesh,
     is_draft_model: bool = False,
 ) -> nnx.Module:
+    model_dtype = to_jax_dtype(vllm_config.model_config.dtype)
+    vllm_config.model_config.dtype = model_dtype
+
     if is_draft_model:
         model_class = _get_model_architecture(
             vllm_config.speculative_config.draft_model_config.hf_config)
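The two added lines normalize model_config.dtype (which vLLM may hand over as a string or a torch dtype) into a JAX dtype before the flax model is built; get_vllm_model below does the mirror-image conversion with to_torch_dtype. The real helpers live in tpu_inference/utils.py and are not part of this diff; a minimal sketch of the assumed behavior:

    import jax.numpy as jnp
    import torch

    # Hypothetical sketch of to_jax_dtype; the actual implementation in
    # tpu_inference/utils.py is not shown in this diff.
    _TORCH_TO_JAX = {
        torch.bfloat16: jnp.bfloat16,
        torch.float16: jnp.float16,
        torch.float32: jnp.float32,
    }
    _NAME_TO_JAX = {"bfloat16": jnp.bfloat16, "float16": jnp.float16,
                    "float32": jnp.float32}

    def to_jax_dtype(dtype):
        if isinstance(dtype, torch.dtype):
            return _TORCH_TO_JAX[dtype]
        if isinstance(dtype, str):
            return _NAME_TO_JAX[dtype]
        return dtype  # assume it is already a JAX-compatible dtype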
@@ -218,7 +240,9 @@ def get_flax_model(
             vllm_config.model_config.hf_config)
     jit_model = _get_nnx_model(model_class, vllm_config, rng, mesh)
     kv_cache_sharding = NamedSharding(
-        mesh,
+        mesh,
+        PartitionSpec(ShardingAxisName.ATTN_DATA, None,
+                      ShardingAxisName.ATTN_HEAD))
     hidden_states_sharding = NamedSharding(mesh,
                                            PartitionSpec(
                                                ShardingAxisName.ATTN_DATA,
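The new kv-cache spec shards axis 0 across the attention data-parallel mesh axis and the last axis across the attention-head axis, replicating the middle axis (None). A self-contained sketch of the same NamedSharding/PartitionSpec mechanics, using generic axis names in place of the package's ShardingAxisName constants:

    import jax
    import numpy as np
    from jax.sharding import NamedSharding, PartitionSpec

    # Assumes 4 local devices (e.g. run CPU-only with
    # XLA_FLAGS=--xla_force_host_platform_device_count=4).
    mesh = jax.make_mesh((1, 4), ("data", "model"))

    # Axis 0 -> "data", axis 1 replicated, axis 2 -> "model": the same shape
    # of spec as the kv-cache sharding in the hunk above.
    spec = NamedSharding(mesh, PartitionSpec("data", None, "model"))
    kv = jax.device_put(np.zeros((8, 16, 8), np.float32), spec)
    print(kv.sharding.spec)  # PartitionSpec('data', None, 'model')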
@@ -236,14 +260,17 @@ def get_flax_model(
             hidden_states_sharding,  # aux hidden states
         ),
         donate_argnums=2,  # 0 is graphdef, 1 is state, 2 is kv_cache
-        static_argnums=
+        static_argnums=(
+            7, 10, 11
+        ),  #7 is layer_name_to_kvcache_index, 10 is is_first_rank, 11 is is_last_rank
     )
     def run_model(graphdef, state, *args):
         model = nnx.merge(graphdef, state)
         return model(*args)
 
     logits_sharding = NamedSharding(
-        mesh,
+        mesh,
+        PartitionSpec(ShardingAxisName.MLP_DATA, ShardingAxisName.MLP_TENSOR))
 
     @functools.partial(
         jax.jit,
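static_argnums marks arguments 7, 10 and 11 of the jitted callable as compile-time constants: they must be hashable, their values are baked into the trace, and a new value forces a retrace. A minimal illustration with an argument that genuinely has to be static because it determines an output shape:

    import functools
    import jax
    import jax.numpy as jnp

    @functools.partial(jax.jit, static_argnums=(1,))
    def top_k_values(x, k):
        # k shapes the output, so it must be a compile-time constant.
        return jax.lax.top_k(x, k)[0]

    x = jnp.arange(8.0)
    top_k_values(x, 2)  # traces and compiles for k=2
    top_k_values(x, 2)  # cache hit
    top_k_values(x, 3)  # new static value -> retrace and recompile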
@@ -256,10 +283,9 @@ def get_flax_model(
 
     # Multi-modal support only
     # This function calculates the image token's embeddings by VIT
-    def
-                **kwargs):
+    def run_embed_multimodal(graphdef, state, image_grid_thw, **kwargs):
         model = nnx.merge(graphdef, state)
-        return model.
+        return model.embed_multimodal(image_grid_thw, **kwargs)
 
     embed_sharding = NamedSharding(mesh, PartitionSpec(None))
     # This function will calculates the embeddings of input texts and then merge with the image embeddings
@@ -267,9 +293,9 @@ def get_flax_model(
         jax.jit,
         out_shardings=(embed_sharding),
     )
-    def
+    def run_embed_input_ids(graphdef, state, *args, **kwargs):
         model = nnx.merge(graphdef, state)
-        return model.
+        return model.embed_input_ids(*args, **kwargs)
 
     # For models that want to work with EAGLE-3 speculative decoding
     @functools.partial(
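The renamed helpers above (run_embed_multimodal, run_embed_input_ids) follow the flax nnx functional pattern used throughout this file: nnx.split turns a module into a hashable graphdef plus a pytree state, and the jitted function re-assembles them with nnx.merge inside the trace. A runnable sketch with a toy module:

    import jax
    import jax.numpy as jnp
    from flax import nnx

    class Tiny(nnx.Module):
        def __init__(self, rngs: nnx.Rngs):
            self.proj = nnx.Linear(4, 4, rngs=rngs)

        def __call__(self, x):
            return self.proj(x)

    graphdef, state = nnx.split(Tiny(nnx.Rngs(0)))

    @jax.jit
    def run_model(graphdef, state, x):
        model = nnx.merge(graphdef, state)  # rebuild the module inside the trace
        return model(x)

    print(run_model(graphdef, state, jnp.ones((2, 4))).shape)  # (2, 4)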
@@ -285,10 +311,8 @@ def get_flax_model(
                                None)
     model_fn = functools.partial(run_model, graphdef)
     compute_logits_fn = functools.partial(run_compute_logits, graphdef)
-
-
-    get_input_embeddings_fn = functools.partial(run_get_input_embeddings,
-                                                graphdef)
+    embed_multimodal_fn = functools.partial(run_embed_multimodal, graphdef)
+    embed_input_ids_fn = functools.partial(run_embed_input_ids, graphdef)
     lora_manager, model = None, None
     combine_hidden_states_fn = functools.partial(combine_hidden_states,
                                                  graphdef)
@@ -299,8 +323,8 @@ def get_flax_model(
 
     multimodal_fns = {
         "precompile_vision_encoder_fn": precompile_vision_encoder_fn,
-        "
-        "
+        "embed_multimodal_fn": embed_multimodal_fn,
+        "embed_input_ids_fn": embed_input_ids_fn,
         "get_mrope_input_positions_fn": get_mrope_input_positions_fn,
     }
 
@@ -312,6 +336,8 @@ def get_vllm_model(
     rng: jax.Array,
     mesh: Mesh,
 ):
+    model_dtype = to_torch_dtype(vllm_config.model_config.dtype)
+    vllm_config.model_config.dtype = model_dtype
     from tpu_inference.models.vllm.vllm_model_wrapper import VllmModelWrapper
 
     model = VllmModelWrapper(
@@ -337,24 +363,39 @@ def get_model(
     impl = envs.MODEL_IMPL_TYPE
     logger.info(f"Loading model with MODEL_IMPL_TYPE={impl}")
 
-    if impl == "
-
-
-
-
-
-
-
-
-
-
-
-
+    if impl == "auto":
+        # Resolve "auto" based on architecture
+        architectures = getattr(vllm_config.model_config.hf_config,
+                                "architectures", [])
+        assert len(architectures) == 1, (
+            f"Expected exactly one architecture, got {len(architectures)}: "
+            f"{architectures}")
+        arch = architectures[0]
+        impl = "vllm" if arch in _VLLM_PREFERRED_ARCHITECTURES else "flax_nnx"
+        logger.info(f"Resolved MODEL_IMPL_TYPE 'auto' to '{impl}'")
+
+    match impl:
+        case "flax_nnx":
+            if vllm_config.parallel_config.pipeline_parallel_size > 1:
+                logger.warning(
+                    "PP is not fully supported on Jax flax_nnx models yet, fallback to vllm models."
+                )
+                return get_vllm_model(vllm_config, rng, mesh)
+            try:
+                # Try to load the flax model first
+                return get_flax_model(vllm_config, rng, mesh, is_draft_model)
+            except UnsupportedArchitectureError as e:
+                # Convert the error message to a string to check its contents
+                error_msg = str(e)
+
+                logger.warning(error_msg)
+
+                # Fall back to the vLLM model and updating the dtype accordingly
+                return get_vllm_model(vllm_config, rng, mesh)
+        case "vllm":
             return get_vllm_model(vllm_config, rng, mesh)
-
-
-    else:
-        raise NotImplementedError("Unsupported MODEL_IMPL_TYPE")
+        case _:
+            raise NotImplementedError(f"Unsupported MODEL_IMPL_TYPE: {impl}")
 
 
 def _validate_model_interface(model: Any) -> None:
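Distilled from the hunk above, the "auto" resolution rule is small enough to restate standalone (resolve_impl is a hypothetical name; in the package the logic is inlined in get_model):

    _VLLM_PREFERRED_ARCHITECTURES = frozenset({"GptOssForCausalLM"})

    def resolve_impl(impl: str, architectures: list[str]) -> str:
        # "auto" picks per architecture; explicit values pass through to the
        # match statement, which rejects anything it does not recognize.
        if impl == "auto":
            assert len(architectures) == 1, architectures
            return ("vllm" if architectures[0] in _VLLM_PREFERRED_ARCHITECTURES
                    else "flax_nnx")
        return impl

    assert resolve_impl("auto", ["GptOssForCausalLM"]) == "vllm"
    assert resolve_impl("auto", ["LlamaForCausalLM"]) == "flax_nnx"
    assert resolve_impl("vllm", ["LlamaForCausalLM"]) == "vllm"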
@@ -441,14 +482,14 @@ def register_model(arch: str, model: Any) -> None:
     )
 
     # Same as `forward`, this is a dummy method to satisfy vLLM's type checks.
-    def
+    def unimplemented_embed_input_ids(
         self,
         input_ids: "torch.Tensor",
         positions: "torch.Tensor",
         inputs_embeds: Optional["torch.Tensor"] = None,
     ) -> "torch.Tensor":
         raise NotImplementedError(
-            "This is a JAX model and does not implement the PyTorch
+            "This is a JAX model and does not implement the PyTorch embed_input_ids method."
         )
 
     # We need a custom __init__ that only calls torch.nn.Module's init,
@@ -464,7 +505,7 @@ def register_model(arch: str, model: Any) -> None:
         {
             "__init__": wrapper_init,
             "forward": unimplemented_forward,
-            "
+            "embed_input_ids": unimplemented_embed_input_ids,
             # Prevent vLLM from trying to load weights into this dummy class.
             "load_weights": lambda self, *args, **kwargs: None,
         })
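The dict in the last hunk is the method table of a dynamically built placeholder class; the enclosing type(...) call sits outside the hunk, so the sketch below assumes the conventional three-argument form and hypothetical names:

    import torch

    def wrapper_init(self):
        # Only torch.nn.Module's __init__: no parameters are allocated.
        torch.nn.Module.__init__(self)

    def unimplemented_forward(self, *args, **kwargs):
        raise NotImplementedError(
            "This is a JAX model and does not implement the PyTorch forward method.")

    # Hypothetical reconstruction of the call that consumes the namespace dict.
    DummyTorchModel = type(
        "DummyTorchModel",
        (torch.nn.Module,),
        {
            "__init__": wrapper_init,
            "forward": unimplemented_forward,
            "load_weights": lambda self, *args, **kwargs: None,
        })

    m = DummyTorchModel()
    m.load_weights()  # deliberate no-op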
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.