tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +78 -1
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +38 -7
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +17 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +28 -5
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +74 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +88 -25
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -64
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +72 -37
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +45 -15
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +41 -16
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +42 -36
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +63 -50
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
- tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import copy
 import functools
 import os
 from typing import TYPE_CHECKING, Callable, List
@@ -41,10 +42,36 @@ DEFAULT_DEEPSEEK_FP8_CONFIG = {
         "scale_dtype":
         "bfloat16",
         "rules": [
+            # Exclude router from quantization
             {
                 "module_path": ".*.custom_module.router.*",
                 "weight_qtype": None,
             },
+            # Avoid the combine expert ops
+            {
+                "module_path": ".*combine_experts.*",
+                "weight_qtype": None,
+            },
+            # Attention layers: keep FP8 for weights and activations
+            {
+                "module_path": ".*.attn.*",
+                "weight_qtype": "float8_e4m3fn",
+                "act_qtype": "float8_e4m3fn",
+            },
+            # MoE experts: use FP4 for expert weights
+            {
+                "module_path": ".*.custom_module.*",
+                "weight_qtype": "float4_e2m1fn",
+                "act_qtype": "float8_e4m3fn",
+                "tile_size": 256,
+            },
+            # Shared experts: also FP4
+            {
+                "module_path": ".*.shared_experts.*",
+                "weight_qtype": "float4_e2m1fn",
+                "act_qtype": "float8_e4m3fn",
+                "tile_size": 256,
+            },
             {
                 "module_path": ".*",
                 "weight_qtype": "float8_e4m3fn",
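
The rule list above is order-sensitive: the specific router/expert exclusions and the FP4 expert rules must precede the catch-all `.*` FP8 rule, since a module is governed by the first rule whose `module_path` pattern matches it. A minimal sketch of that first-match behavior, assuming full-pattern regex semantics (the `pick_rule` helper is hypothetical, written only to show how the patterns interact; Qwix's real matcher lives inside the qwix library):

import re

def pick_rule(module_path: str, rules: list[dict]) -> dict | None:
    """Hypothetical first-match resolver over a Qwix-style rule list."""
    for rule in rules:
        if re.fullmatch(rule["module_path"], module_path):
            return rule
    return None

rules = [
    {"module_path": ".*.custom_module.router.*", "weight_qtype": None},
    {"module_path": ".*.attn.*", "weight_qtype": "float8_e4m3fn"},
    {"module_path": ".*.custom_module.*", "weight_qtype": "float4_e2m1fn"},
    {"module_path": ".*", "weight_qtype": "float8_e4m3fn"},
]

# The router hits its dedicated exclusion before the broader FP4 expert rule;
# anything unmatched falls through to the FP8 catch-all.
assert pick_rule("layer.0.custom_module.router.gate", rules)["weight_qtype"] is None
assert pick_rule("layer.0.custom_module.experts", rules)["weight_qtype"] == "float4_e2m1fn"
assert pick_rule("layer.0.mlp.up_proj", rules)["weight_qtype"] == "float8_e4m3fn"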
@@ -166,9 +193,11 @@ def qwix_quantize_nnx_model(model: nnx.Module, qwix_config: List[dict],
         head_size=kv_cache_head_size,
         mesh=mesh,
         layer_names=[f"layer.{i}" for i in range(num_hidden_layers)],
-        cache_dtype=kv_cache_jnp_dtype
+        cache_dtype=kv_cache_jnp_dtype,
+        use_mla=model.vllm_config.model_config.use_mla,
+    )

-    dp_size =
+    dp_size = model.vllm_config.sharding_config.total_dp_size

     # NOTE: the inputs don't need to match the actual ones, as long as the consumed weights are the same
     input_ids = jax.random.randint(rng,
@@ -396,8 +425,7 @@ def apply_qwix_on_abstract_model(vllm_config: "VllmConfig") -> bool:


 def get_default_qwix_quantization_config(
-
-        skip_quantization: bool) -> dict | None:
+        hf_config: dict, skip_quantization: bool) -> dict | None:
     """
     Some models are pre-quantized and in those cases, we want to return a default set of
     Qwix quantization rules (instead of forcing the user to pass in a quantization config each time).
@@ -415,9 +443,42 @@ def get_default_qwix_quantization_config(
     """
     if skip_quantization:
         return None
-
+    model_type = hf_config.model_type.lower() if hasattr(
+        hf_config, "model_type") else None
+    quant_method = hf_config.quantization_config["quant_method"] if hasattr(
+        hf_config, "quantization_config") else None
+    # TODO (jacobplatin): remove this so that we can support various quantization types + make
+    # more flexible
+    # NOTE (jacobplatin): we'll default to mixed FP8 (attention) + FP4 (MoE experts)
+    # for DeepSeek
     if model_type == "deepseek_v3" and quant_method == "fp8":
-
+        config = copy.deepcopy(DEFAULT_DEEPSEEK_FP8_CONFIG)
+
+        # Dynamically fetch block size from HF config if available
+        # Config fmt: 'weight_block_size': [1, 512] -> we want the 2nd dim for tile_size
+        # NOTE: if the checkpoint is not 1D subchannel, we will throw an error
+        hf_quant_config = hf_config.quantization_config
+        assert "weight_block_size" in hf_quant_config, "Expected weight_block_size in quantization_config"
+        block_size = hf_quant_config["weight_block_size"]
+        if isinstance(block_size, (list, tuple)) and len(block_size) == 2:
+            assert block_size[
+                0] == 1, f"Expected first dimension to be 1 (unchanneled), but got {block_size[0]}!"
+            tile_size = block_size[1]
+            assert tile_size > 1, f"Expected tile_size > 1 for DeepSeek, but got {tile_size}"
+            logger.info(
+                f"Detected DeepSeek tile_size from config: {tile_size}")
+
+            # Update tile_size in the rules, since we might not always use a 1D subchannel size of
+            # 256
+            for rule in config["qwix"]["rules"]:
+                if "tile_size" in rule:
+                    rule["tile_size"] = tile_size
+        else:
+            raise ValueError(
+                f"Invalid weight_block_size config: {block_size}, expected a list/tuple of length 2"
+            )
+
+        return config
     elif model_type == "llama4" and quant_method == "compressed-tensors":
         return DEFAULT_LLAMA4_FP8_CONFIG
     # MXFP4 (GPT-OSS): provide a default configuration to quantize MoE experts via Qwix
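
For a DeepSeek checkpoint whose HF quantization_config carries `'weight_block_size': [1, 512]`, the branch above asserts the first dimension is 1 (1D sub-channel) and propagates 512 into every rule that declares a `tile_size`, replacing the default 256. A standalone sketch of that override (the `override_tile_size` helper is illustrative; it assumes the `{"qwix": {"rules": [...]}}` layout the diff indexes via `config["qwix"]["rules"]`):

import copy

def override_tile_size(config: dict, weight_block_size) -> dict:
    """Propagate the checkpoint's sub-channel size into the Qwix rules."""
    assert isinstance(weight_block_size, (list, tuple)) and len(weight_block_size) == 2
    assert weight_block_size[0] == 1, "only 1D sub-channel checkpoints are handled"
    config = copy.deepcopy(config)
    for rule in config["qwix"]["rules"]:
        if "tile_size" in rule:
            rule["tile_size"] = weight_block_size[1]  # e.g. 256 -> 512
    return config

cfg = {"qwix": {"rules": [
    {"module_path": ".*.shared_experts.*", "weight_qtype": "float4_e2m1fn", "tile_size": 256},
    {"module_path": ".*", "weight_qtype": "float8_e4m3fn"},
]}}
assert override_tile_size(cfg, [1, 512])["qwix"]["rules"][0]["tile_size"] == 512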
@@ -436,14 +497,10 @@ def update_vllm_config_for_qwix_quantization(vllm_config: "VllmConfig"):
     # Qwix quantization config accordingly
     # NOTE: if a Qwix config is provided (via the `additional_config`), we'll
     # use that instead
-
-    ) if hasattr(vllm_config.model_config.hf_config, "model_type") else None
-    quant_method = vllm_config.model_config.hf_config.quantization_config[
-        "quant_method"] if hasattr(vllm_config.model_config.hf_config,
-                                   "quantization_config") else None
+    hf_config = vllm_config.model_config.hf_config
     default_quantization_config = get_default_qwix_quantization_config(
-
-
+        hf_config, vllm_config.additional_config.get("skip_quantization",
+                                                     False))

     maybe_existing_quantization_config = vllm_config.additional_config.get(
         "quantization")
@@ -500,7 +557,14 @@ def get_random_sharded_array(key: PRNGKey, mesh: Mesh, param: nnx.Param,
         maxval = jnp.array(jnp.iinfo(dtype).max, dtype=dtype)
         weight = jax.random.randint(key, param_shape, minval, maxval, dtype)
     else:
-
+        # NOTE: _uniform() in random.py does not accept float4_e2m1fn
+        # Error: "TypeError: uniform only accepts 8-, 16-, 32-, or 64-bit dtypes, got float4_e2m1fn."
+        # Workaround: call function with dtype jnp.float8_e4m3fn and cast back to float4_e2m1fn
+        if dtype != "float4_e2m1fn":
+            weight = jax.random.normal(key, param_shape, dtype)
+        else:
+            weight = jax.random.normal(key, param_shape,
+                                       jnp.float8_e4m3fn).astype(dtype)

     def get_slice(index):
         return weight[index]
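
The workaround in this hunk exists because JAX's float samplers reject sub-byte dtypes; sampling at 8 bits and casting down loses nothing that matters for random initialization. A self-contained sketch of the same trick (shape is illustrative, and it assumes a JAX/ml_dtypes build that exposes `float4_e2m1fn`):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
shape = (4, 8)  # illustrative parameter shape

# jax.random.normal raises TypeError for float4_e2m1fn, so sample in
# float8_e4m3fn and downcast afterwards.
weight = jax.random.normal(key, shape, jnp.float8_e4m3fn).astype(jnp.float4_e2m1fn)
assert weight.dtype == jnp.float4_e2m1fn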
@@ -535,18 +599,16 @@ def load_random_weights_into_qwix_abstract_model(rng: PRNGKey,
     logger.info("Initializing Qwix-quantized model with random weights...")
     # TODO (jacobplatin): clean up this logic
     scale_dtype = model.weight_loader.scale_dtype
-    scale_shape_map = model.weight_loader.
+    scale_shape_map = model.weight_loader.scale_shape_map_for_random_weight_loading if hasattr(
         model.weight_loader,
-        '
+        'scale_shape_map_for_random_weight_loading') else {}
     quantization_block_sizes = quantization_config["weight_block_size"]
     assert len(
         quantization_block_sizes
     ) == 2, f"Expected only 2 quantization block sizes but got {quantization_block_sizes}"
-    quantization_block_size_n, _ = quantization_block_sizes[
-        0], quantization_block_sizes[1]

     # Iterate through all variables and initialize them
-
+
     for path, param in nnx.iter_graph(model):
         if not isinstance(param, nnx.Variable):
             continue
@@ -556,16 +618,17 @@ def load_random_weights_into_qwix_abstract_model(rng: PRNGKey,
         is_qwix_scale = (path[-1] == 'scale' and path[-2] == "array")
         param_dtype = scale_dtype if is_qwix_scale else param.value.dtype
         param_shape = param.value.shape
-        # TODO (jacobplatin): clean this up
         if is_qwix_scale:
-
-
-
-
+            key = f"{path[2]}.{path[3]}"
+
+            if key in scale_shape_map:
+                param_shape = scale_shape_map[key]
+            else:
+                raise ValueError(
+                    f"Scale shape for {key} not found in scale_shape_map.")
         param.value = get_random_sharded_array(
             rng, mesh, param, param_shape, param_dtype,
             ".".join([str(x) for x in path]))
-        prev_param_shape = param_shape

     # Handles the DeepSeek case, where this needs to be called to make the cache weights
     # concrete
@@ -1,3 +1,16 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """Utilities for downloading model weights from HuggingFace."""

 import functools
@@ -67,7 +80,13 @@ def transpose_params(param_key: str, param_tensor: jax.Array, transpose_map):
 def reshape_params(param_key: str, param_tensor: jax.Array, shape_map):
     for key, new_shape in shape_map.items():
         if key in param_key:
-
+            try:
+                #TODO:(gpolovets) Add validation on whether reshape preserves data layout.
+                return jnp.reshape(param_tensor, new_shape)
+            except TypeError:
+                raise TypeError(
+                    f"Cannot reshape for key={key}, new_shape={new_shape}, param_shape={param_tensor.shape}"
+                )
     return param_tensor  # Base case / no-op
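
The try/except above exists because `jnp.reshape` raises a bare TypeError when element counts disagree; the wrapper re-raises with the offending map key and shapes attached. A quick illustration of the failure mode being guarded against:

import jax.numpy as jnp

x = jnp.ones((3, 4))
jnp.reshape(x, (2, 6))      # fine: 12 elements either way
try:
    jnp.reshape(x, (5, 3))  # 12 elements cannot fill 15 slots
except TypeError as e:
    print(f"reshape rejected: {e}")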
@@ -275,7 +294,8 @@ def _load_and_shard_weight(vllm_config,
                            hf_key: str,
                            hf_weight: jax.Array,
                            keep_original_dtype_keys_regex: list[str]
-                           | None = None
+                           | None = None,
+                           pp_missing_layers: list[str] | None = None):
     name_map = metadata_map.name_map
     reshape_keys = metadata_map.reshape_map
     bias_reshape_keys = metadata_map.bias_reshape_map
@@ -331,6 +351,10 @@ def _load_and_shard_weight(vllm_config,
         return
     model_key = name_map.get(hf_key, hf_key)

+    if pp_missing_layers and _is_pp_missing_layer(hf_key, pp_missing_layers):
+        logger.warning(
+            f"Skip loading {hf_key} as it doesn't belong to this PP stage.")
+        return
     model_weight, model_sharding = get_param_and_sharding(
         params, shardings, model_key)

@@ -394,6 +418,14 @@ def _load_and_shard_weight(vllm_config,
     model_weight.value = shard(hf_weight, spec)


+def _is_pp_missing_layer(hf_key: str, pp_missing_layers: list[str]) -> bool:
+    has_digit = any(char.isdigit() for char in hf_key)
+    # Append a "." after digit-bearing names so "layers.1" does not match "layers.10"
+    suffix = "." if has_digit else ""
+    return any(f'{pp_missing_layer}{suffix}' in hf_key
+               for pp_missing_layer in pp_missing_layers)
+
+
 def _load_hf_weights_on_thread(
     vllm_config: VllmConfig,
     params: nnx.State,
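
The trailing `"."` in `_is_pp_missing_layer` is what keeps substring matching from conflating layer indices that share a prefix: a pipeline stage missing `layers.1` must not also drop `layers.10`. A quick check against the helper as defined in the hunk above:

def _is_pp_missing_layer(hf_key: str, pp_missing_layers: list[str]) -> bool:
    has_digit = any(char.isdigit() for char in hf_key)
    suffix = "." if has_digit else ""
    return any(f"{pp_missing_layer}{suffix}" in hf_key
               for pp_missing_layer in pp_missing_layers)

missing = ["model.layers.1"]
assert _is_pp_missing_layer("model.layers.1.self_attn.q_proj.weight", missing)
# "model.layers.1." is not a substring of the layer-10 key, so it loads normally:
assert not _is_pp_missing_layer("model.layers.10.self_attn.q_proj.weight", missing)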
@@ -402,6 +434,7 @@ def _load_hf_weights_on_thread(
     weights_file: str,
     filter_regex: Optional[str] = None,
     keep_original_dtype_keys_regex: Optional[list[str]] = None,
+    pp_missing_layers: list[str] | None = None,
 ):
     """Loads weights from a single weights file."""
     try:
@@ -420,6 +453,7 @@ def _load_hf_weights_on_thread(
             hf_key,
             hf_weight,
             keep_original_dtype_keys_regex,
+            pp_missing_layers,
         )

@@ -431,6 +465,7 @@ def load_hf_weights(
     filter_regex: Optional[str] = None,
     is_draft_model: bool = False,
     keep_original_dtype_keys_regex: Optional[list[str]] = None,
+    pp_missing_layers: list[str] | None = None,
 ):
     """Load weights into a JAX model from either an iterator or files."""
     params = nnx.state(model)
@@ -461,6 +496,7 @@ def load_hf_weights(
                 hf_key,
                 hf_weight_jax,
                 keep_original_dtype_keys_regex,
+                pp_missing_layers=pp_missing_layers,
             )
         else:
             # File-based path (multi-threaded)
@@ -488,6 +524,7 @@ def load_hf_weights(
                     filter_regex=filter_regex,
                     keep_original_dtype_keys_regex=
                     keep_original_dtype_keys_regex,
+                    pp_missing_layers=pp_missing_layers,
                 ) for weights_file in weights_files
             ]
             for future in futures:
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import copy
 import functools
 from collections.abc import Sequence
@@ -23,6 +37,7 @@ from vllm.model_executor.models import supports_lora, supports_multimodal
 from vllm.sequence import IntermediateTensors

 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
 from tpu_inference.layers.vllm.sharding import shard_model_to_tpu
 from tpu_inference.logger import init_logger
@@ -197,7 +212,7 @@ class VllmModelWrapper:
             kwargs={
                 "input_ids": torch_view(input_ids),
                 "positions": torch_view(input_positions),
-                "intermediate_tensors":
+                "intermediate_tensors": intermediate_tensors,
                 "inputs_embeds": None,
             },
             tie_weights=False,
@@ -220,8 +235,10 @@ class VllmModelWrapper:

         @functools.partial(
             jax.jit,
-            out_shardings=(NamedSharding(
-
+            out_shardings=(NamedSharding(
+                self.mesh,
+                PartitionSpec(ShardingAxisName.MLP_DATA,
+                              ShardingAxisName.MLP_TENSOR))),
         )
         def compute_logits_func(
             params_and_buffers: Any,
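
Passing `out_shardings` here pins the jitted logits computation to an explicit data x tensor layout on the mesh instead of whatever layout the compiler would pick. A minimal runnable sketch of the same pattern (generic axis names stand in for `ShardingAxisName.MLP_DATA`/`MLP_TENSOR`; on one device this degenerates to a 1x1 mesh):

import functools
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()).reshape(1, -1), ("data", "model"))

@functools.partial(
    jax.jit,
    out_shardings=NamedSharding(mesh, PartitionSpec("data", "model")),
)
def compute_logits(hidden):
    # Stand-in for the real logits head: (batch, hidden) @ (hidden, vocab).
    return hidden @ jnp.ones((hidden.shape[-1], 32))

out = compute_logits(jnp.ones((8, 16)))
print(out.sharding)  # NamedSharding over ("data", "model")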
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import Dict, List, Optional
@@ -1,2 +1,16 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 # ruff: noqa
 from tpu_inference.platforms.tpu_platform import TpuPlatform
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import TYPE_CHECKING,
+from typing import TYPE_CHECKING, Optional, Tuple, Union, cast

 import jax.numpy as jnp
 import torch
@@ -8,23 +8,25 @@ import vllm.envs as vllm_envs
 from tpu_info import device
 from vllm.inputs import ProcessorInputs, PromptType
 from vllm.platforms.interface import Platform, PlatformEnum
-from vllm.sampling_params import SamplingParams, SamplingType

 from tpu_inference import envs
 from tpu_inference.layers.common.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
-from tpu_inference.utils import to_jax_dtype, to_torch_dtype

 if TYPE_CHECKING:
-    from vllm.attention.backends.registry import
+    from vllm.attention.backends.registry import AttentionBackendEnum
+    from vllm.attention.selector import AttentionSelectorConfig
     from vllm.config import BlockSize, ModelConfig, VllmConfig
     from vllm.pooling_params import PoolingParams
+    from vllm.sampling_params import SamplingParams, SamplingType
 else:
     BlockSize = None
     ModelConfig = None
     VllmConfig = None
     PoolingParams = None
-
+    AttentionBackendEnum = None
+    SamplingParams = None
+    SamplingType = None

 logger = init_logger(__name__)

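
The import shuffle in this hunk follows the standard `TYPE_CHECKING` idiom: annotation-only names are imported under the type-checking guard (with `None` placeholders so the names still resolve at runtime), and the real classes are imported lazily where they are used, keeping heavy vLLM imports off the platform module's import path. A minimal sketch of the idiom (it assumes vllm is importable at the point the function actually runs):

from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
    # Seen only by type checkers; costs nothing at runtime.
    from vllm.sampling_params import SamplingParams
else:
    SamplingParams = None  # placeholder so the name exists at runtime

def validate_request(params: Union["SamplingParams", object]) -> None:
    # Late import: pay for the real module only when called.
    from vllm.sampling_params import SamplingParams
    if isinstance(params, SamplingParams):
        ...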
@@ -44,25 +46,21 @@ class TpuPlatform(Platform):

     additional_env_vars: list[str] = [
         "PHASED_PROFILING_DIR", "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS",
-        "TPU_MULTIHOST_BACKEND", "VLLM_MLA_DISABLE", "TPU_BACKEND_TYPE"
+        "TPU_MULTIHOST_BACKEND", "VLLM_MLA_DISABLE", "TPU_BACKEND_TYPE",
+        "NEW_MODEL_DESIGN"
     ]

     @classmethod
-    def get_attn_backend_cls(cls, selected_backend: "
-
-
-
-
-
-        if selected_backend != _Backend.PALLAS:
+    def get_attn_backend_cls(cls, selected_backend: "AttentionBackendEnum",
+                             attn_selector_config: "AttentionSelectorConfig",
+                             **kwargs) -> str:
+        from vllm.attention.backends.registry import AttentionBackendEnum
+
+        if selected_backend != AttentionBackendEnum.PALLAS:
             logger.info("Cannot use %s backend on TPU.", selected_backend)

-
-        return "tpu_inference.layers.vllm.attention.PallasAttentionBackend"
-        else:
-            logger.info("Using Pallas backend.")
-            return "vllm.attention.backends.pallas.PallasAttentionBackend"
+        logger.info("Using Pallas V1 backend.")
+        return "tpu_inference.layers.vllm.attention.PallasAttentionBackend"

     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
@@ -146,39 +144,21 @@ class TpuPlatform(Platform):
         if compilation_config.backend == "":
             compilation_config.backend = "openxla"

-        # If we use vLLM's model implementation in PyTorch, we should set it with torch version of the dtype.
-        impl = envs.MODEL_IMPL_TYPE
-
-        # NOTE(xiang): convert dtype to jnp.dtype
-        # NOTE(wenlong): skip this logic for mm model preprocessing
-        # For mm model preprocessors, it may need the output dtype to be torch.
-        # In order to avoid a PR to vLLM, we postpone the dtype checking during
-        # tpu_worker initialization
-        if not vllm_config.scheduler_config.is_multimodal_model or impl == "vllm":
-            model_dtype = vllm_config.model_config.dtype
-            try:
-                dtype = to_jax_dtype(model_dtype)
-            except ValueError:
-                logger.warning(f"{model_dtype=} is not supported. "
-                               "Falling back to jnp.bfloat16")
-                dtype = jnp.bfloat16
-            if impl == "vllm":
-                dtype = to_torch_dtype(dtype)
-            vllm_config.model_config.dtype = dtype
-
         # TODO(cuiq): remove this dependency.
-
-
-
-
-
-
-
-
-
-
-
-
+        if vllm_config.model_config:
+            from vllm.v1.attention.backends.pallas import \
+                PallasAttentionBackend
+            cache_config.block_size = PallasAttentionBackend.get_page_size(
+                vllm_config)  # type: ignore[assignment]
+            min_page_size = PallasAttentionBackend.get_min_page_size(
+                vllm_config)
+            if min_page_size > cache_config.block_size:
+                logger.warning(
+                    "Increase the page size from %s to %s to avoid SMEM OOM",
+                    cache_config.block_size,
+                    min_page_size,
+                )
+                cache_config.block_size = min_page_size  # type: ignore[assignment]

         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config
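
The block-size adjustment above is a simple floor: if the Pallas backend's minimum viable page size exceeds the configured block size, the block size is raised (with a warning) to avoid SMEM OOM. In sketch form, with a hypothetical clamp helper and illustrative numbers:

def clamp_block_size(block_size: int, min_page_size: int) -> int:
    if min_page_size > block_size:
        print(f"Increase the page size from {block_size} to {min_page_size} "
              "to avoid SMEM OOM")
        return min_page_size
    return block_size

assert clamp_block_size(16, 32) == 32  # too small: bumped to the floor
assert clamp_block_size(64, 32) == 64  # already large enough: unchanged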
@@ -188,12 +168,12 @@ class TpuPlatform(Platform):
         multihost_backend = envs.TPU_MULTIHOST_BACKEND
         if not multihost_backend:  # Single host
             if parallel_config.pipeline_parallel_size == 1:
-                logger.info("Force using UniProcExecutor for JAX on
-
+                logger.info("Force using UniProcExecutor for JAX on "
+                            "single host without pipeline parallelism.")
                 parallel_config.distributed_executor_backend = "uni"
             else:
-                logger.info("Force using MultiprocExecutor for JAX on
-
+                logger.info("Force using MultiprocExecutor for JAX on "
+                            "single host with pipeline parallelism.")
                 parallel_config.distributed_executor_backend = "mp"
         elif multihost_backend == "ray":
             from tpu_inference.executors.ray_distributed_executor import \
@@ -209,19 +189,21 @@ class TpuPlatform(Platform):

         if scheduler_config.is_multimodal_model and not \
             scheduler_config.disable_chunked_mm_input:
-            logger.warning("TPU does not support running Multimodal models"
-
-
+            logger.warning("TPU does not support running Multimodal models"
+                           " without setting `--disable_chunked_mm_input`. "
+                           "Forcing --disable_chunked_mm_input.")
             scheduler_config.disable_chunked_mm_input = True

         kv_transfer_config = vllm_config.kv_transfer_config
         if kv_transfer_config is not None:
             assert kv_transfer_config.kv_connector == "TPUConnector"
-        # Late initialization to avoid circular import
-
-
-
-
+        # Late initialization to avoid circular import.
+        # Only perform qwix quantization if it is a jax model.
+        if vllm_config.model_config is not None:
+            from tpu_inference.models.jax.utils.qwix.qwix_utils import \
+                update_vllm_config_for_qwix_quantization
+            if vllm_config.model_config:
+                update_vllm_config_for_qwix_quantization(vllm_config)

         from tpu_inference.core.sched.dp_scheduler import \
             update_vllm_config_for_dp_scheduler
@@ -249,10 +231,11 @@ class TpuPlatform(Platform):
     def validate_request(
         cls,
         prompt: PromptType,
-        params: Union[SamplingParams, PoolingParams],
+        params: Union["SamplingParams", PoolingParams],
         processed_inputs: ProcessorInputs,
     ) -> None:
         """Raises if this request is unsupported on this platform"""
+        from vllm.sampling_params import SamplingParams, SamplingType

         if isinstance(params, SamplingParams):
             if params.sampling_type == SamplingType.RANDOM_SEED:
tpu_inference/runner/__init__.py CHANGED
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.