tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +78 -1
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +38 -7
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +17 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +28 -5
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +74 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +88 -25
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -64
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +72 -37
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +45 -15
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +41 -16
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +42 -36
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +63 -50
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
- tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
import functools
|
|
2
16
|
from typing import Any, Optional
|
|
3
17
|
|
|
@@ -5,7 +19,6 @@ import jax
|
|
|
5
19
|
import torch
|
|
6
20
|
from flax import nnx
|
|
7
21
|
from jax.sharding import Mesh, NamedSharding, PartitionSpec
|
|
8
|
-
from torchax.ops.mappings import j2t_dtype
|
|
9
22
|
from transformers import PretrainedConfig
|
|
10
23
|
from vllm.config import VllmConfig
|
|
11
24
|
from vllm.model_executor.model_loader import get_model_loader
|
|
@@ -16,14 +29,20 @@ from vllm.utils.func_utils import supports_kw
|
|
|
16
29
|
from tpu_inference import envs
|
|
17
30
|
from tpu_inference.layers.common.sharding import ShardingAxisName
|
|
18
31
|
from tpu_inference.logger import init_logger
|
|
19
|
-
from tpu_inference.models.jax.utils.
|
|
32
|
+
from tpu_inference.models.jax.utils.qwix.qwix_utils import (
|
|
20
33
|
apply_qwix_on_abstract_model, apply_qwix_quantization,
|
|
21
34
|
load_random_weights_into_qwix_abstract_model)
|
|
35
|
+
from tpu_inference.utils import to_jax_dtype, to_torch_dtype
|
|
22
36
|
|
|
23
37
|
logger = init_logger(__name__)
|
|
24
38
|
|
|
25
39
|
_MODEL_REGISTRY = {}
|
|
26
40
|
|
|
41
|
+
# List of architectures that are preferred to use "vllm" implementation over
|
|
42
|
+
# "flax_nnx" implementation due to various factors such as performance.
|
|
43
|
+
_VLLM_PREFERRED_ARCHITECTURES: frozenset[str] = frozenset(
|
|
44
|
+
{"GptOssForCausalLM"})
|
|
45
|
+
|
|
27
46
|
|
|
28
47
|
class UnsupportedArchitectureError(ValueError):
|
|
29
48
|
"""Raised when a model architecture is not supported in the registry."""
|
|
@@ -210,6 +229,9 @@ def get_flax_model(
|
|
|
210
229
|
mesh: Mesh,
|
|
211
230
|
is_draft_model: bool = False,
|
|
212
231
|
) -> nnx.Module:
|
|
232
|
+
model_dtype = to_jax_dtype(vllm_config.model_config.dtype)
|
|
233
|
+
vllm_config.model_config.dtype = model_dtype
|
|
234
|
+
|
|
213
235
|
if is_draft_model:
|
|
214
236
|
model_class = _get_model_architecture(
|
|
215
237
|
vllm_config.speculative_config.draft_model_config.hf_config)
|
|
@@ -218,7 +240,9 @@ def get_flax_model(
|
|
|
218
240
|
vllm_config.model_config.hf_config)
|
|
219
241
|
jit_model = _get_nnx_model(model_class, vllm_config, rng, mesh)
|
|
220
242
|
kv_cache_sharding = NamedSharding(
|
|
221
|
-
mesh,
|
|
243
|
+
mesh,
|
|
244
|
+
PartitionSpec(ShardingAxisName.ATTN_DATA, None,
|
|
245
|
+
ShardingAxisName.ATTN_HEAD))
|
|
222
246
|
hidden_states_sharding = NamedSharding(mesh,
|
|
223
247
|
PartitionSpec(
|
|
224
248
|
ShardingAxisName.ATTN_DATA,
|
|
@@ -245,7 +269,8 @@ def get_flax_model(
|
|
|
245
269
|
return model(*args)
|
|
246
270
|
|
|
247
271
|
logits_sharding = NamedSharding(
|
|
248
|
-
mesh,
|
|
272
|
+
mesh,
|
|
273
|
+
PartitionSpec(ShardingAxisName.MLP_DATA, ShardingAxisName.MLP_TENSOR))
|
|
249
274
|
|
|
250
275
|
@functools.partial(
|
|
251
276
|
jax.jit,
|
|
@@ -258,10 +283,9 @@ def get_flax_model(
|
|
|
258
283
|
|
|
259
284
|
# Multi-modal support only
|
|
260
285
|
# This function calculates the image token's embeddings by VIT
|
|
261
|
-
def
|
|
262
|
-
**kwargs):
|
|
286
|
+
def run_embed_multimodal(graphdef, state, image_grid_thw, **kwargs):
|
|
263
287
|
model = nnx.merge(graphdef, state)
|
|
264
|
-
return model.
|
|
288
|
+
return model.embed_multimodal(image_grid_thw, **kwargs)
|
|
265
289
|
|
|
266
290
|
embed_sharding = NamedSharding(mesh, PartitionSpec(None))
|
|
267
291
|
# This function will calculates the embeddings of input texts and then merge with the image embeddings
|
|
@@ -269,9 +293,9 @@ def get_flax_model(
|
|
|
269
293
|
jax.jit,
|
|
270
294
|
out_shardings=(embed_sharding),
|
|
271
295
|
)
|
|
272
|
-
def
|
|
296
|
+
def run_embed_input_ids(graphdef, state, *args, **kwargs):
|
|
273
297
|
model = nnx.merge(graphdef, state)
|
|
274
|
-
return model.
|
|
298
|
+
return model.embed_input_ids(*args, **kwargs)
|
|
275
299
|
|
|
276
300
|
# For models that want to work with EAGLE-3 speculative decoding
|
|
277
301
|
@functools.partial(
|
|
@@ -287,10 +311,8 @@ def get_flax_model(
|
|
|
287
311
|
None)
|
|
288
312
|
model_fn = functools.partial(run_model, graphdef)
|
|
289
313
|
compute_logits_fn = functools.partial(run_compute_logits, graphdef)
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
get_input_embeddings_fn = functools.partial(run_get_input_embeddings,
|
|
293
|
-
graphdef)
|
|
314
|
+
embed_multimodal_fn = functools.partial(run_embed_multimodal, graphdef)
|
|
315
|
+
embed_input_ids_fn = functools.partial(run_embed_input_ids, graphdef)
|
|
294
316
|
lora_manager, model = None, None
|
|
295
317
|
combine_hidden_states_fn = functools.partial(combine_hidden_states,
|
|
296
318
|
graphdef)
|
|
@@ -301,8 +323,8 @@ def get_flax_model(
|
|
|
301
323
|
|
|
302
324
|
multimodal_fns = {
|
|
303
325
|
"precompile_vision_encoder_fn": precompile_vision_encoder_fn,
|
|
304
|
-
"
|
|
305
|
-
"
|
|
326
|
+
"embed_multimodal_fn": embed_multimodal_fn,
|
|
327
|
+
"embed_input_ids_fn": embed_input_ids_fn,
|
|
306
328
|
"get_mrope_input_positions_fn": get_mrope_input_positions_fn,
|
|
307
329
|
}
|
|
308
330
|
|
|
@@ -314,6 +336,8 @@ def get_vllm_model(
|
|
|
314
336
|
rng: jax.Array,
|
|
315
337
|
mesh: Mesh,
|
|
316
338
|
):
|
|
339
|
+
model_dtype = to_torch_dtype(vllm_config.model_config.dtype)
|
|
340
|
+
vllm_config.model_config.dtype = model_dtype
|
|
317
341
|
from tpu_inference.models.vllm.vllm_model_wrapper import VllmModelWrapper
|
|
318
342
|
|
|
319
343
|
model = VllmModelWrapper(
|
|
@@ -339,24 +363,39 @@ def get_model(
|
|
|
339
363
|
impl = envs.MODEL_IMPL_TYPE
|
|
340
364
|
logger.info(f"Loading model with MODEL_IMPL_TYPE={impl}")
|
|
341
365
|
|
|
342
|
-
if impl == "
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
366
|
+
if impl == "auto":
|
|
367
|
+
# Resolve "auto" based on architecture
|
|
368
|
+
architectures = getattr(vllm_config.model_config.hf_config,
|
|
369
|
+
"architectures", [])
|
|
370
|
+
assert len(architectures) == 1, (
|
|
371
|
+
f"Expected exactly one architecture, got {len(architectures)}: "
|
|
372
|
+
f"{architectures}")
|
|
373
|
+
arch = architectures[0]
|
|
374
|
+
impl = "vllm" if arch in _VLLM_PREFERRED_ARCHITECTURES else "flax_nnx"
|
|
375
|
+
logger.info(f"Resolved MODEL_IMPL_TYPE 'auto' to '{impl}'")
|
|
376
|
+
|
|
377
|
+
match impl:
|
|
378
|
+
case "flax_nnx":
|
|
379
|
+
if vllm_config.parallel_config.pipeline_parallel_size > 1:
|
|
380
|
+
logger.warning(
|
|
381
|
+
"PP is not fully supported on Jax flax_nnx models yet, fallback to vllm models."
|
|
382
|
+
)
|
|
383
|
+
return get_vllm_model(vllm_config, rng, mesh)
|
|
384
|
+
try:
|
|
385
|
+
# Try to load the flax model first
|
|
386
|
+
return get_flax_model(vllm_config, rng, mesh, is_draft_model)
|
|
387
|
+
except UnsupportedArchitectureError as e:
|
|
388
|
+
# Convert the error message to a string to check its contents
|
|
389
|
+
error_msg = str(e)
|
|
390
|
+
|
|
391
|
+
logger.warning(error_msg)
|
|
392
|
+
|
|
393
|
+
# Fall back to the vLLM model and updating the dtype accordingly
|
|
394
|
+
return get_vllm_model(vllm_config, rng, mesh)
|
|
395
|
+
case "vllm":
|
|
355
396
|
return get_vllm_model(vllm_config, rng, mesh)
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
else:
|
|
359
|
-
raise NotImplementedError("Unsupported MODEL_IMPL_TYPE")
|
|
397
|
+
case _:
|
|
398
|
+
raise NotImplementedError(f"Unsupported MODEL_IMPL_TYPE: {impl}")
|
|
360
399
|
|
|
361
400
|
|
|
362
401
|
def _validate_model_interface(model: Any) -> None:
|
|
@@ -443,14 +482,14 @@ def register_model(arch: str, model: Any) -> None:
|
|
|
443
482
|
)
|
|
444
483
|
|
|
445
484
|
# Same as `forward`, this is a dummy method to satisfy vLLM's type checks.
|
|
446
|
-
def
|
|
485
|
+
def unimplemented_embed_input_ids(
|
|
447
486
|
self,
|
|
448
487
|
input_ids: "torch.Tensor",
|
|
449
488
|
positions: "torch.Tensor",
|
|
450
489
|
inputs_embeds: Optional["torch.Tensor"] = None,
|
|
451
490
|
) -> "torch.Tensor":
|
|
452
491
|
raise NotImplementedError(
|
|
453
|
-
"This is a JAX model and does not implement the PyTorch
|
|
492
|
+
"This is a JAX model and does not implement the PyTorch embed_input_ids method."
|
|
454
493
|
)
|
|
455
494
|
|
|
456
495
|
# We need a custom __init__ that only calls torch.nn.Module's init,
|
|
@@ -466,7 +505,7 @@ def register_model(arch: str, model: Any) -> None:
|
|
|
466
505
|
{
|
|
467
506
|
"__init__": wrapper_init,
|
|
468
507
|
"forward": unimplemented_forward,
|
|
469
|
-
"
|
|
508
|
+
"embed_input_ids": unimplemented_embed_input_ids,
|
|
470
509
|
# Prevent vLLM from trying to load weights into this dummy class.
|
|
471
510
|
"load_weights": lambda self, *args, **kwargs: None,
|
|
472
511
|
})
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|