tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
Potentially problematic release: the registry has flagged this version of tpu-inference; see the registry's advisory page for details.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +14 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +25 -8
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +14 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +20 -3
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +20 -26
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +22 -3
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +100 -455
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
- tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +37 -16
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +113 -124
- tpu_inference/models/jax/gpt_oss.py +23 -7
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
- tpu_inference/models/jax/utils/weight_utils.py +32 -1
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +27 -29
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +69 -35
- tpu_inference/runner/kv_cache.py +14 -0
- tpu_inference/runner/kv_cache_manager.py +15 -2
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +30 -10
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +31 -30
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +23 -7
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -208
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import copy
 import functools
 import os
 from typing import TYPE_CHECKING, Callable, List
@@ -34,17 +35,43 @@ DEFAULT_NUM_TOKENS_FOR_MODEL_INPUTS = 512
 DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS = 256
 DEFAULT_MAX_NUM_BLOCKS_PER_REQ = 16

-
+DEFAULT_DEEPSEEK_FP4_MLP_MOE_FP8_ATTN_CONFIG = {
     "qwix": {
         "use_abstract_model":
         True,
         "scale_dtype":
         "bfloat16",
         "rules": [
+            # Exclude router from quantization
             {
                 "module_path": ".*.custom_module.router.*",
                 "weight_qtype": None,
             },
+            # Avoid the combine expert ops
+            {
+                "module_path": ".*combine_experts.*",
+                "weight_qtype": None,
+            },
+            # Attention layers: keep FP8 for weights and activations
+            {
+                "module_path": ".*.attn.*",
+                "weight_qtype": "float8_e4m3fn",
+                "act_qtype": "float8_e4m3fn",
+            },
+            # MoE experts: use FP4 for expert weights
+            {
+                "module_path": ".*.custom_module.*",
+                "weight_qtype": "float4_e2m1fn",
+                "act_qtype": "float8_e4m3fn",
+                "tile_size": 256,
+            },
+            # Shared experts: also FP4
+            {
+                "module_path": ".*.shared_experts.*",
+                "weight_qtype": "float4_e2m1fn",
+                "act_qtype": "float8_e4m3fn",
+                "tile_size": 256,
+            },
             {
                 "module_path": ".*",
                 "weight_qtype": "float8_e4m3fn",
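A note on the rule list above: the ordering suggests first-match resolution, with the router/combine-expert exclusions ahead of the FP4 expert rules and the catch-all ".*" FP8 rule last. A minimal sketch of that first-match lookup, hedged: resolve_rule and the trimmed RULES list are illustrative stand-ins, not code from the package, and first-match semantics is an assumption drawn from the rule ordering.

import re

# Hypothetical, trimmed-down stand-in for the "rules" list above.
RULES = [
    {"module_path": ".*.custom_module.router.*", "weight_qtype": None},
    {"module_path": ".*.attn.*", "weight_qtype": "float8_e4m3fn"},
    {"module_path": ".*", "weight_qtype": "float8_e4m3fn"},
]

def resolve_rule(module_path: str) -> dict:
    # First matching regex wins, so the catch-all ".*" must stay last.
    for rule in RULES:
        if re.fullmatch(rule["module_path"], module_path):
            return rule
    return {}

# The router stays unquantized; non-attention modules fall through to FP8.
assert resolve_rule("model.custom_module.router.gate")["weight_qtype"] is None
assert resolve_rule("model.layers.0.mlp.up_proj")["weight_qtype"] == "float8_e4m3fn"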
@@ -398,8 +425,7 @@ def apply_qwix_on_abstract_model(vllm_config: "VllmConfig") -> bool:


 def get_default_qwix_quantization_config(
-
-        skip_quantization: bool) -> dict | None:
+        hf_config: dict, skip_quantization: bool) -> dict | None:
     """
     Some models are pre-quantized and in those cases, we want to return a default set of
     Qwix quantization rules (instead of forcing the user to pass in a quantization config each time).
@@ -417,9 +443,42 @@
     """
     if skip_quantization:
         return None
-
+    model_type = hf_config.model_type.lower() if hasattr(
+        hf_config, "model_type") else None
+    quant_method = hf_config.quantization_config["quant_method"] if hasattr(
+        hf_config, "quantization_config") else None
+    # TODO (jacobplatin): remove this so that we can support various quantization types + make
+    # more flexible
+    # NOTE (jacobplatin): we'll default to mixed FP8 (attention) + FP4 (MoE experts)
+    # for DeepSeek
     if model_type == "deepseek_v3" and quant_method == "fp8":
-
+        config = copy.deepcopy(DEFAULT_DEEPSEEK_FP4_MLP_MOE_FP8_ATTN_CONFIG)
+
+        # Dynamically fetch block size from HF config if available
+        # Config fmt: 'weight_block_size': [1, 512] -> we want the 2nd dim for tile_size
+        # NOTE: if the checkpoint is not 1D subchannel, we will throw an error
+        hf_quant_config = hf_config.quantization_config
+        assert "weight_block_size" in hf_quant_config, "Expected weight_block_size in quantization_config"
+        block_size = hf_quant_config["weight_block_size"]
+        if isinstance(block_size, (list, tuple)) and len(block_size) == 2:
+            assert block_size[
+                0] == 1, f"Expected first dimension to be 1 (unchanneled), but got {block_size[0]}! If you are trying to run quantized DeepSeek, we currently only support 1D-subchannel quantization and those models can be found here: https://huggingface.co/collections/jrplatin/deepseek-r1-1d-subchannel"
+            tile_size = block_size[1]
+            assert tile_size > 1, f"Expected tile_size > 1 for DeepSeek, but got {tile_size}"
+            logger.info(
+                f"Detected DeepSeek tile_size from config: {tile_size}")
+
+            # Update tile_size in the rules, since we might not always use a 1D subchannel size of
+            # 256
+            for rule in config["qwix"]["rules"]:
+                if "tile_size" in rule:
+                    rule["tile_size"] = tile_size
+        else:
+            raise ValueError(
+                f"Invalid weight_block_size config: {block_size}, expected a list/tuple of length 2"
+            )
+
+        return config
     elif model_type == "llama4" and quant_method == "compressed-tensors":
         return DEFAULT_LLAMA4_FP8_CONFIG
     # MXFP4 (GPT-OSS): provide a default configuration to quantize MoE experts via Qwix
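For context on the weight_block_size handling above: a [1, N] block size means one scale per N-element block along the last weight axis (1D sub-channel quantization). A toy NumPy illustration of that dequantization layout; the shapes and values are made up for the example, not taken from the package.

import numpy as np

tile = 512                                  # tile_size read from weight_block_size[1]
w_q = np.ones((4, 1024), dtype=np.int8)     # stand-in for the packed fp8/fp4 codes
scales = np.full((4, 1024 // tile), 0.5, dtype=np.float32)  # one scale per block

# Dequantize block-wise: each 512-wide slice of a row is scaled by its own factor.
w = (w_q.astype(np.float32).reshape(4, -1, tile) * scales[:, :, None]).reshape(4, 1024)
assert w.shape == (4, 1024) and w[0, 0] == 0.5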
@@ -438,14 +497,10 @@ def update_vllm_config_for_qwix_quantization(vllm_config: "VllmConfig"):
     # Qwix quantization config accordingly
     # NOTE: if a Qwix config is provided (via the`additional_config`), we'll
     # use that instead
-
-    ) if hasattr(vllm_config.model_config.hf_config, "model_type") else None
-    quant_method = vllm_config.model_config.hf_config.quantization_config[
-        "quant_method"] if hasattr(vllm_config.model_config.hf_config,
-                                   "quantization_config") else None
+    hf_config = vllm_config.model_config.hf_config
     default_quantization_config = get_default_qwix_quantization_config(
-
-
+        hf_config, vllm_config.additional_config.get("skip_quantization",
+                                                     False))

     maybe_existing_quantization_config = vllm_config.additional_config.get(
         "quantization")
@@ -502,7 +557,14 @@ def get_random_sharded_array(key: PRNGKey, mesh: Mesh, param: nnx.Param,
         maxval = jnp.array(jnp.iinfo(dtype).max, dtype=dtype)
         weight = jax.random.randint(key, param_shape, minval, maxval, dtype)
     else:
-
+        # NOTE: _uniform() in random.py does not accept float4_e2m1fn
+        # Error: "TypeError: uniform only accepts 8-, 16-, 32-, or 64-bit dtypesgot float4_e2m1fn."
+        # Workaround: call function with dtype jnp.float8_e4m3fn and cast back to float4_e2m1fn
+        if dtype != "float4_e2m1fn":
+            weight = jax.random.normal(key, param_shape, dtype)
+        else:
+            weight = jax.random.normal(key, param_shape,
+                                       jnp.float8_e4m3fn).astype(dtype)

     def get_slice(index):
         return weight[index]
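The cast-through-float8 workaround above can be reproduced standalone. A sketch, assuming a JAX build whose bundled ml_dtypes exposes jnp.float4_e2m1fn:

import jax
import jax.numpy as jnp

key = jax.random.key(0)

# Sampling directly in float4_e2m1fn raises TypeError ("uniform only accepts
# 8-, 16-, 32-, or 64-bit dtypes"), so sample in an 8-bit float and cast down.
w = jax.random.normal(key, (8, 16), jnp.float8_e4m3fn).astype(jnp.float4_e2m1fn)
assert w.dtype == jnp.float4_e2m1fn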
@@ -537,18 +599,16 @@ def load_random_weights_into_qwix_abstract_model(rng: PRNGKey,
     logger.info("Initializing Qwix-quantized model with random weights...")
     # TODO (jacobplatin): clean up this logic
     scale_dtype = model.weight_loader.scale_dtype
-    scale_shape_map = model.weight_loader.
+    scale_shape_map = model.weight_loader.scale_shape_map_for_random_weight_loading if hasattr(
         model.weight_loader,
-        '
+        'scale_shape_map_for_random_weight_loading') else {}
     quantization_block_sizes = quantization_config["weight_block_size"]
     assert len(
         quantization_block_sizes
     ) == 2, f"Expected only 2 quantization block sizes but got {quantization_block_sizes}"
-    quantization_block_size_n, _ = quantization_block_sizes[
-        0], quantization_block_sizes[1]

     # Iterate through all variables and initialize them
-
+
     for path, param in nnx.iter_graph(model):
         if not isinstance(param, nnx.Variable):
             continue
@@ -558,16 +618,17 @@ def load_random_weights_into_qwix_abstract_model(rng: PRNGKey,
         is_qwix_scale = (path[-1] == 'scale' and path[-2] == "array")
         param_dtype = scale_dtype if is_qwix_scale else param.value.dtype
         param_shape = param.value.shape
-        # TODO (jacobplatin): clean this up
         if is_qwix_scale:
-
-
-
-
+            key = f"{path[2]}.{path[3]}"
+
+            if key in scale_shape_map:
+                param_shape = scale_shape_map[key]
+            else:
+                raise ValueError(
+                    f"Scale shape for {key} not found in scale_shape_map.")
         param.value = get_random_sharded_array(
             rng, mesh, param, param_shape, param_dtype,
             ".".join([str(x) for x in path]))
-        prev_param_shape = param_shape

         # Handles the DeepSeek case, where this needs to be called to make the cache weights
         # concrete
@@ -1,3 +1,16 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """Utilities for downloading model weights from HuggingFace."""

 import functools
@@ -281,7 +294,8 @@ def _load_and_shard_weight(vllm_config,
                            hf_key: str,
                            hf_weight: jax.Array,
                            keep_original_dtype_keys_regex: list[str]
-                           | None = None
+                           | None = None,
+                           pp_missing_layers: list[str] | None = None):
     name_map = metadata_map.name_map
     reshape_keys = metadata_map.reshape_map
     bias_reshape_keys = metadata_map.bias_reshape_map
@@ -337,6 +351,10 @@ def _load_and_shard_weight(vllm_config,
         return
     model_key = name_map.get(hf_key, hf_key)

+    if pp_missing_layers and _is_pp_missing_layer(hf_key, pp_missing_layers):
+        logger.warning(
+            f"Skip loading {hf_key} as it doesn't belong to this PP stage.")
+        return
     model_weight, model_sharding = get_param_and_sharding(
         params, shardings, model_key)

@@ -400,6 +418,14 @@ def _load_and_shard_weight(vllm_config,
     model_weight.value = shard(hf_weight, spec)


+def _is_pp_missing_layer(hf_key: str, pp_missing_layers: list[str]) -> bool:
+    has_digit = any(char.isdigit() for char in hf_key)
+    # add the suffix after digits to avoid it matches "layers.10" with "layers.1"
+    suffix = "." if has_digit else ""
+    return any(f'{pp_missing_layer}{suffix}' in hf_key
+               for pp_missing_layer in pp_missing_layers)
+
+
 def _load_hf_weights_on_thread(
     vllm_config: VllmConfig,
     params: nnx.State,
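The digit-suffix guard in _is_pp_missing_layer keeps a stage-prefix like "layers.1" from also matching "layers.10". A quick standalone behavior check; the helper is re-declared here so the snippet runs on its own:

def _is_pp_missing_layer(hf_key: str, pp_missing_layers: list[str]) -> bool:
    # Mirrors the helper above: require a trailing "." when the key has digits.
    has_digit = any(char.isdigit() for char in hf_key)
    suffix = "." if has_digit else ""
    return any(f'{layer}{suffix}' in hf_key for layer in pp_missing_layers)

assert _is_pp_missing_layer("model.layers.1.self_attn.q_proj.weight", ["layers.1"])
assert not _is_pp_missing_layer("model.layers.10.self_attn.q_proj.weight", ["layers.1"])
assert _is_pp_missing_layer("lm_head", ["lm_head"])  # no digits: plain substring match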
@@ -408,6 +434,7 @@ def _load_hf_weights_on_thread(
     weights_file: str,
     filter_regex: Optional[str] = None,
     keep_original_dtype_keys_regex: Optional[list[str]] = None,
+    pp_missing_layers: list[str] | None = None,
 ):
     """Loads weights from a single weights file."""
     try:
@@ -426,6 +453,7 @@ def _load_hf_weights_on_thread(
                 hf_key,
                 hf_weight,
                 keep_original_dtype_keys_regex,
+                pp_missing_layers,
             )


@@ -437,6 +465,7 @@ def load_hf_weights(
     filter_regex: Optional[str] = None,
     is_draft_model: bool = False,
     keep_original_dtype_keys_regex: Optional[list[str]] = None,
+    pp_missing_layers: list[str] | None = None,
 ):
     """Load weights into a JAX model from either an iterator or files."""
     params = nnx.state(model)
@@ -467,6 +496,7 @@
                 hf_key,
                 hf_weight_jax,
                 keep_original_dtype_keys_regex,
+                pp_missing_layers=pp_missing_layers,
             )
     else:
         # File-based path (multi-threaded)
@@ -494,6 +524,7 @@
                 filter_regex=filter_regex,
                 keep_original_dtype_keys_regex=
                 keep_original_dtype_keys_regex,
+                pp_missing_layers=pp_missing_layers,
             ) for weights_file in weights_files
         ]
         for future in futures:
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import copy
 import functools
 from collections.abc import Sequence
@@ -23,8 +37,10 @@ from vllm.model_executor.models import supports_lora, supports_multimodal
 from vllm.sequence import IntermediateTensors

 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.vllm.process_weights.cleanup_sharding import \
+    shard_model_to_tpu
 from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
-from tpu_inference.layers.vllm.sharding import shard_model_to_tpu
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.jax_intermediate_tensor import \
     JaxIntermediateTensors
@@ -197,7 +213,7 @@ class VllmModelWrapper:
             kwargs={
                 "input_ids": torch_view(input_ids),
                 "positions": torch_view(input_positions),
-                "intermediate_tensors":
+                "intermediate_tensors": intermediate_tensors,
                 "inputs_embeds": None,
             },
             tie_weights=False,
@@ -220,8 +236,10 @@ class VllmModelWrapper:

         @functools.partial(
             jax.jit,
-            out_shardings=(NamedSharding(
-
+            out_shardings=(NamedSharding(
+                self.mesh,
+                PartitionSpec(ShardingAxisName.MLP_DATA,
+                              ShardingAxisName.MLP_TENSOR))),
         )
         def compute_logits_func(
             params_and_buffers: Any,
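For readers unfamiliar with the out_shardings change above: jit's out_shardings pins the layout of the returned array to a named mesh partition, here (data, tensor) for the logits. A self-contained toy on a single-device mesh; the axis name "data" is illustrative, not ShardingAxisName's actual value.

import functools
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()[:1]), ("data",))  # toy single-axis mesh

@functools.partial(
    jax.jit,
    out_shardings=NamedSharding(mesh, PartitionSpec("data")),
)
def double(x):
    # The output is constrained to be row-sharded along the "data" mesh axis.
    return x * 2

y = double(jnp.ones((8, 4)))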
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import Dict, List, Optional
@@ -1,2 +1,16 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 # ruff: noqa
 from tpu_inference.platforms.tpu_platform import TpuPlatform
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import TYPE_CHECKING,
+from typing import TYPE_CHECKING, Optional, Tuple, Union, cast

 import jax.numpy as jnp
 import torch
@@ -15,6 +15,7 @@ from tpu_inference.logger import init_logger

 if TYPE_CHECKING:
     from vllm.attention.backends.registry import AttentionBackendEnum
+    from vllm.attention.selector import AttentionSelectorConfig
     from vllm.config import BlockSize, ModelConfig, VllmConfig
     from vllm.pooling_params import PoolingParams
     from vllm.sampling_params import SamplingParams, SamplingType
@@ -51,11 +52,10 @@

     @classmethod
     def get_attn_backend_cls(cls, selected_backend: "AttentionBackendEnum",
-
-
-                             use_mla: bool, has_sink: bool, use_sparse: bool,
-                             use_mm_prefix: bool, attn_type: Any) -> str:
+                             attn_selector_config: "AttentionSelectorConfig",
+                             **kwargs) -> str:
         from vllm.attention.backends.registry import AttentionBackendEnum
+
         if selected_backend != AttentionBackendEnum.PALLAS:
             logger.info("Cannot use %s backend on TPU.", selected_backend)

@@ -145,17 +145,20 @@
         compilation_config.backend = "openxla"

         # TODO(cuiq): remove this dependency.
-
-
-
-
-
-
-
-
-
-
-
+        if vllm_config.model_config:
+            from vllm.v1.attention.backends.pallas import \
+                PallasAttentionBackend
+            cache_config.block_size = PallasAttentionBackend.get_page_size(
+                vllm_config)  # type: ignore[assignment]
+            min_page_size = PallasAttentionBackend.get_min_page_size(
+                vllm_config)
+            if min_page_size > cache_config.block_size:
+                logger.warning(
+                    "Increase the page size from %s to %s to avoid SMEM OOM",
+                    cache_config.block_size,
+                    min_page_size,
+                )
+                cache_config.block_size = min_page_size  # type: ignore[assignment]

         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config
@@ -165,12 +168,12 @@
         multihost_backend = envs.TPU_MULTIHOST_BACKEND
         if not multihost_backend:  # Single host
             if parallel_config.pipeline_parallel_size == 1:
-                logger.info("Force using UniProcExecutor for JAX on
-
+                logger.info("Force using UniProcExecutor for JAX on "
+                            "single host without pipeline parallelism.")
                 parallel_config.distributed_executor_backend = "uni"
             else:
-                logger.info("Force using MultiprocExecutor for JAX on
-
+                logger.info("Force using MultiprocExecutor for JAX on "
+                            "single host with pipeline parallelism.")
                 parallel_config.distributed_executor_backend = "mp"
         elif multihost_backend == "ray":
             from tpu_inference.executors.ray_distributed_executor import \
@@ -186,20 +189,15 @@

         if scheduler_config.is_multimodal_model and not \
             scheduler_config.disable_chunked_mm_input:
-            logger.warning("TPU does not support running Multimodal models"
-
-
+            logger.warning("TPU does not support running Multimodal models"
+                           " without setting `--disable_chunked_mm_input`. "
+                           "Forcing --disable_chunked_mm_input.")
             scheduler_config.disable_chunked_mm_input = True

         kv_transfer_config = vllm_config.kv_transfer_config
         if kv_transfer_config is not None:
             assert kv_transfer_config.kv_connector == "TPUConnector"
-        # Late initialization to avoid circular import
-        from tpu_inference.models.jax.utils.quantization.quantization_utils import \
-            update_vllm_config_for_qwix_quantization
-
-        update_vllm_config_for_qwix_quantization(vllm_config)
-
+        # Late initialization to avoid circular import.
         from tpu_inference.core.sched.dp_scheduler import \
             update_vllm_config_for_dp_scheduler
         update_vllm_config_for_dp_scheduler(vllm_config)
tpu_inference/runner/__init__.py CHANGED
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import time
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple

@@ -32,6 +46,8 @@ class CompilationManager:

     def __init__(self, runner: "TPUModelRunner"):
         self.runner = runner
+        self._sampling_precompiled = False
+        self._gather_logprobs_precompiled = False
         if not vllm_envs.VLLM_DISABLE_COMPILE_CACHE:
             logger.info("Enabling JAX compile cache.")
             jax.config.update("jax_compilation_cache_dir",
@@ -86,9 +102,13 @@
             return
         self._precompile_select_from_array()
         self._precompile_compute_logits()
+        # Skip sampling if already precompiled before KV cache allocation
+        if not self._sampling_precompiled:
+            self._precompile_sampling()
         self._precompile_disagg_utils()
-
-        self.
+        # Skip gather_logprobs if already precompiled before KV cache allocation
+        if not self._gather_logprobs_precompiled:
+            self._precompile_gather_logprobs()
         self._precompile_structured_decoding()
         if self.runner.speculative_config:
             self._precompile_speculative_decoding()
@@ -107,7 +127,7 @@

         self._run_compilation(
             "input_embeddings_merger",
-            self.runner.
+            self.runner.embed_input_ids_fn,
             self.runner.state,
             dummy_input_ids,
             dummy_multimodal_embeddings,
@@ -116,7 +136,7 @@

         self._run_compilation(
             "input_embeddings_merger_text_only",
-            self.runner.
+            self.runner.embed_input_ids_fn,
             self.runner.state,
             dummy_input_ids,
             None,
@@ -475,35 +495,39 @@
         logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16,
                                            logits_sharding)
         for do_sampling in (True, False):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            for logprobs in (True, False):
+                if do_sampling:
+                    temperature = np.full((num_reqs, ),
+                                          0.7,
+                                          dtype=np.float32)
+                    top_k = np.full((num_reqs, ), 20, dtype=np.int32)
+                    top_p = np.full((num_reqs, ), 0.8, dtype=np.float32)
+                    (temperature, top_k, top_p) = device_array(
+                        self.runner.mesh, (temperature, top_k, top_p),
+                        sharding=sampling_metadata_sharding)
+                else:
+                    temperature = None
+                    top_k = None
+                    top_p = None
+
+                sampling_metadata = TPUSupportedSamplingMetadata(
+                    temperature=temperature,
+                    top_k=top_k,
+                    top_p=top_p,
+                    do_sampling=do_sampling,
+                    logprobs=logprobs)
+                self._run_compilation(
+                    f"worker{self.runner.rank} sample",
+                    sample,
+                    self.runner.rng_params_for_sampling,
+                    self.runner.mesh,
+                    logits,
+                    sampling_metadata,
+                    num_reqs=num_reqs,
+                    do_sampling=do_sampling,
+                )
+
+        self._sampling_precompiled = True

     def _precompile_disagg_utils(self) -> None:
         if not is_disagg_enabled():
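The nested loops above precompile one jitted sample variant per (do_sampling, logprobs) pair, since each flag changes the traced graph, and the _sampling_precompiled flag prevents recompiling if this already ran before KV-cache allocation. A sketch of that guard-plus-enumeration pattern; the class and method names here are illustrative, not the runner's API:

import itertools

class PrecompileOnce:
    def __init__(self):
        self._sampling_precompiled = False

    def precompile_sampling(self, compile_fn):
        if self._sampling_precompiled:
            return  # already compiled earlier; skip the duplicate work
        # One compilation per static-flag combination, as in the hunk above.
        for do_sampling, logprobs in itertools.product((True, False), repeat=2):
            compile_fn(do_sampling=do_sampling, logprobs=logprobs)
        self._sampling_precompiled = True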
@@ -533,8 +557,16 @@
         logger.info("Compiling gather_logprobs with different input shapes.")
         hsize = self.runner.model_config.get_vocab_size()
         for num_reqs in self.runner.num_reqs_paddings:
-
-
+            logits_sharding = NamedSharding(
+                self.runner.mesh,
+                PartitionSpec(ShardingAxisName.MLP_DATA,
+                              ShardingAxisName.MLP_TENSOR))
+            token_ids_sharding = NamedSharding(
+                self.runner.mesh, PartitionSpec(ShardingAxisName.MLP_DATA, ))
+            logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16,
+                                               logits_sharding)
+            token_ids = self._create_dummy_tensor((num_reqs, ), jnp.int32,
+                                                  token_ids_sharding)
             self._run_compilation(
                 f"worker{self.runner.rank} gather_logprobs",
                 self.runner._compute_and_gather_logprobs,
@@ -544,6 +576,8 @@
                 num_reqs=num_reqs,
             )

+        self._gather_logprobs_precompiled = True
+
     def _precompile_speculative_decoding(self) -> None:
         logger.info(
             "Compiling speculative_decoding with different input shapes.")