tpu-inference 0.11.1.dev202511150811__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_dp_scheduler.py +899 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/fused_moe_v1_test.py +105 -0
- tests/kernels/mla_v1_test.py +396 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/conftest.py +32 -0
- tests/lora/test_bgmv.py +43 -0
- tests/lora/test_layers.py +654 -0
- tests/lora/test_lora.py +133 -0
- tests/lora/utils.py +96 -0
- tests/test_base.py +201 -0
- tests/test_envs.py +182 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +236 -0
- tpu_inference/__init__.py +34 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/core/sched/__init__.py +0 -0
- tpu_inference/core/sched/dp_scheduler.py +523 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/jax_parallel_state.py +67 -0
- tpu_inference/distributed/tpu_connector.py +728 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +107 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +362 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
- tpu_inference/kernels/mla/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/kernel.py +1349 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_interface.py +390 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/common/sharding.py +582 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +255 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +280 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +96 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
- tpu_inference/layers/jax/transformer_block.py +107 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +507 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +39 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
- tpu_inference/layers/vllm/sharding.py +230 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +311 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +444 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/gpt_oss.py +492 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
- tpu_inference/models/jax/llama3.py +375 -0
- tpu_inference/models/jax/llama4.py +629 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
- tpu_inference/models/jax/utils/weight_utils.py +529 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_platform.py +269 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +780 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +132 -0
- tpu_inference/runner/kv_cache_manager.py +479 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +217 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +248 -0
- tpu_inference/runner/structured_decoding_manager.py +88 -0
- tpu_inference/runner/tpu_runner.py +1620 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +367 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +317 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/tpu_worker.py +321 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,529 @@
|
|
|
1
|
+
"""Utilities for downloading model weights from HuggingFace."""
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
import glob
|
|
5
|
+
import math
|
|
6
|
+
import os
|
|
7
|
+
import re
|
|
8
|
+
from collections.abc import Generator
|
|
9
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
10
|
+
from dataclasses import dataclass, field
|
|
11
|
+
from typing import Any, Optional
|
|
12
|
+
|
|
13
|
+
import jax
|
|
14
|
+
import jax.numpy as jnp
|
|
15
|
+
import torch
|
|
16
|
+
from flax import nnx
|
|
17
|
+
from jax.sharding import Mesh, NamedSharding
|
|
18
|
+
from jax.sharding import PartitionSpec as P
|
|
19
|
+
from safetensors import safe_open
|
|
20
|
+
|
|
21
|
+
from tpu_inference import utils
|
|
22
|
+
from tpu_inference.logger import init_logger
|
|
23
|
+
from tpu_inference.models.jax.utils import file_utils
|
|
24
|
+
|
|
25
|
+
# Module-level logger, created through the project's logging helper.
logger = init_logger(__name__)

# Glob pattern selecting HuggingFace weight files (safetensors); used both for
# local directory globbing and as the pattern passed to the HF download helper.
HF_WEIGHTS_FORMAT = "*.safetensors"
|
|
28
|
+
|
|
29
|
+
DTYPE_VIEW_MAP = {
|
|
30
|
+
jnp.dtype(jnp.float8_e4m3fn): torch.uint8,
|
|
31
|
+
jnp.dtype(jnp.bfloat16): torch.uint16,
|
|
32
|
+
jnp.dtype(jnp.float32): torch.uint32,
|
|
33
|
+
}
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass
class MetadataMap:
    """Lookup tables describing how checkpoint tensors map onto model params.

    Every table defaults to its own empty dict. Apart from ``name_map``, the
    tables are keyed by a fragment that is matched as a substring of the
    checkpoint key.
    """

    # Checkpoint key (layer index replaced by "*") -> model parameter path.
    name_map: dict[str, str] = field(default_factory=dict)
    # Key fragment -> axis permutation handed to jnp.transpose.
    transpose_map: dict[str, tuple[int, ...]] = field(default_factory=dict)
    # Key fragment -> target shape for weight tensors.
    reshape_map: dict[str, tuple[int, ...]] = field(default_factory=dict)
    # Key fragment -> target shape for bias tensors.
    bias_reshape_map: dict[str, tuple[int, ...]] = field(default_factory=dict)
    # Key fragment -> (axis, size) pair for weight tensors.
    pad_map: dict[str, tuple[int, ...]] = field(default_factory=dict)
    # Key fragment -> (axis, size) pair for bias tensors.
    bias_pad_map: dict[str, tuple[int, ...]] = field(default_factory=dict)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
############ START Used by llama4, deepseek only for now START ############
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def print_param_info(param: nnx.Param, name: str):
    """Log the global shape, sharding and first-shard shape of ``param``.

    All three lines are emitted at WARNING level.

    Args:
        param: The parameter to describe.
        name: Label used for the parameter in the log lines.
    """
    global_value = param.value
    logger.warning(f"Global shape for {name}: {global_value.shape}")
    logger.warning(f"Sharding for {name}: {param.sharding}")

    # Only the first addressable shard is inspected.
    first_shard = global_value.addressable_shards[0]
    logger.warning(
        f"Shape of {name} on a single device: {first_shard.data.shape}")
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def transpose_params(param_key: str, param_tensor: jax.Array, transpose_map):
    """Transpose ``param_tensor`` using the first matching map entry.

    Args:
        param_key: Name of the parameter being processed.
        param_tensor: Tensor to (possibly) transpose.
        transpose_map: Mapping of key fragment -> axis permutation. Entries
            are tried in iteration order; a fragment matches when it is a
            substring of ``param_key``.

    Returns:
        The transposed tensor, or ``param_tensor`` itself when nothing matches.
    """
    matched_fragment = next(
        (fragment for fragment in transpose_map if fragment in param_key),
        None,
    )
    if matched_fragment is None:
        return param_tensor  # no rule applies
    return jnp.transpose(param_tensor, transpose_map[matched_fragment])
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def reshape_params(param_key: str, param_tensor: jax.Array, shape_map):
    """Reshape ``param_tensor`` using the first matching map entry.

    Args:
        param_key: Name of the parameter being processed.
        param_tensor: Tensor to (possibly) reshape.
        shape_map: Mapping of key fragment -> new shape. Entries are tried in
            iteration order; a fragment matches when it is a substring of
            ``param_key``.

    Returns:
        The reshaped tensor, or ``param_tensor`` itself when nothing matches.
    """
    matched_fragment = next(
        (fragment for fragment in shape_map if fragment in param_key),
        None,
    )
    if matched_fragment is None:
        return param_tensor  # no rule applies
    return jnp.reshape(param_tensor, shape_map[matched_fragment])
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def model_file_generator(
        model_name_or_path: str,
        download_dir: Optional[str]) -> Generator[str, None, None]:
    """Yield each weights file path for the given model, one at a time.

    The file list comes from ``get_model_weights_files`` and is yielded in the
    order that function returns it.
    """
    yield from get_model_weights_files(model_name_or_path, download_dir)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def model_weights_generator(
    model_name_or_path: str,
    framework: str,
    filter_regex: Optional[str] = None,
    download_dir: Optional[str] = None,
) -> Generator[tuple, None, None]:
    """Yield ``(name, tensor)`` pairs from every weights file of a model.

    Args:
        model_name_or_path: Local directory or HuggingFace model id.
        framework: Framework name forwarded to the single-file reader.
        filter_regex: Optional pattern forwarded to the single-file reader.
        download_dir: Optional download location forwarded to the file lookup.
    """
    for weights_path in model_file_generator(model_name_or_path,
                                             download_dir):
        yield from model_weights_single_file_generator(
            weights_path, framework, filter_regex)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def convert_torch_to_jax_with_view(loaded_weight: torch.Tensor,
                                   cast_type: jnp.dtype) -> jax.Array:
    """Convert a PyTorch tensor to a JAX array by reinterpreting its bits.

    The tensor is viewed as the torch dtype registered for ``cast_type`` in
    ``DTYPE_VIEW_MAP``, exported through NumPy, and the resulting JAX array is
    then viewed as ``cast_type``.

    Args:
        loaded_weight: Source tensor.
        cast_type: Target JAX dtype; must be a key of ``DTYPE_VIEW_MAP``.

    Returns:
        A JAX array of dtype ``cast_type``.

    Raises:
        TypeError: If ``cast_type`` has no entry in ``DTYPE_VIEW_MAP``.
    """
    torch_view_type = DTYPE_VIEW_MAP.get(jnp.dtype(cast_type))
    if torch_view_type is None:
        # Fail here with a readable message instead of passing None to
        # Tensor.view, which only reports an invalid-argument error.
        supported = ", ".join(str(dtype) for dtype in DTYPE_VIEW_MAP)
        raise TypeError(
            f"Unsupported cast_type {jnp.dtype(cast_type)} for bit-view "
            f"conversion; supported dtypes: {supported}")

    raw_bits = loaded_weight.view(torch_view_type).numpy()
    return jnp.array(raw_bits).view(cast_type)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
############ END Used by llama4, deepseek only for now END ############
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def get_model_weights_files(
        model_name_or_path: str,
        download_dir: Optional[str]) -> list[str]:
    """Return the sorted list of safetensors weight files for a model.

    Args:
        model_name_or_path: A local directory containing weight files, or a
            HuggingFace model id.
        download_dir: Passed to the HuggingFace download helper when the
            weights are not local.

    Returns:
        Paths of the files matching ``HF_WEIGHTS_FORMAT``, sorted.

    Raises:
        ValueError: If ``model_name_or_path`` is neither a local directory nor
            a HuggingFace repo.
        RuntimeError: If no matching weight file is found.
    """
    # A local directory takes precedence over a HuggingFace lookup.
    if os.path.isdir(model_name_or_path):
        logger.info(f"Found weights from local: {model_name_or_path}")
        weights_files = glob.glob(
            os.path.join(model_name_or_path, HF_WEIGHTS_FORMAT))
    elif file_utils.is_hf_repo(model_name_or_path):
        logger.info(f"Downloading weights from HF {model_name_or_path}")
        weights_files = file_utils.download_model_weights_from_hf(
            model_name_or_path, download_dir, HF_WEIGHTS_FORMAT)
    else:
        raise ValueError(
            f"{model_name_or_path} must be a local directory, or a Huggingface model id."
        )

    if not weights_files:
        raise RuntimeError(
            f"Cannot find any {HF_WEIGHTS_FORMAT} files in {model_name_or_path}."
        )

    weights_files.sort()
    return weights_files
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def model_weights_single_file_generator(
    weights_file: str,
    framework: str,
    filter_regex: Optional[str] = None,
) -> Generator[tuple, None, None]:
    """Yield ``(name, tensor)`` pairs read from one safetensors file.

    Args:
        weights_file: Path of the safetensors file to open.
        framework: Framework name passed to ``safe_open``.
        filter_regex: When given, only tensors whose name matches this pattern
            (via ``re.match``, i.e. anchored at the start) are yielded.
    """
    logger.info(f"Loading weights from {weights_file}")
    # NOTE: Tensors are deliberately loaded on CPU. Without this they would
    # land on TPU:0 by default; even though they are eventually sharded across
    # multiple TPUs, that would OOM TPU:0 for large models.
    cpu_device = jax.devices("cpu")[0]
    with jax.default_device(cpu_device), \
            safe_open(weights_file, framework=framework) as st_file:
        for tensor_name in st_file.keys():
            selected = filter_regex is None or re.match(
                filter_regex, tensor_name)
            if not selected:
                continue
            yield tensor_name, st_file.get_tensor(tensor_name)
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def get_param(params: nnx.State, path: str) -> nnx.State:
    """Walk ``params`` along a dot-separated ``path`` and return what it finds.

    All-digit segments are used as integer indices without a membership check;
    every other segment must be contained in the current level.

    Raises:
        ValueError: If a non-numeric segment is missing at its level.
    """
    node = params
    for segment in path.split("."):
        if segment.isdigit():
            node = node[int(segment)]
            continue
        if segment not in node:
            raise ValueError(f"{path} is not a valid param path")
        node = node[segment]
    return node
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def get_param_and_sharding(params: nnx.State, shardings: Any,
                           path: str) -> tuple[nnx.State, nnx.State]:
    """Walk ``params`` and ``shardings`` in lockstep along a dotted ``path``.

    All-digit segments are used as integer indices without a membership check;
    every other segment must be contained in the current ``params`` level.

    Returns:
        The entry found in ``params`` and the ``.value`` of the entry found at
        the same path in ``shardings``.

    Raises:
        ValueError: If a non-numeric segment is missing from ``params``.
    """
    param_node, sharding_node = params, shardings
    for segment in path.split("."):
        if segment.isdigit():
            step = int(segment)
        elif segment in param_node:
            step = segment
        else:
            raise ValueError(f"{path} is not a valid param path")
        param_node = param_node[step]
        sharding_node = sharding_node[step]
    return param_node, sharding_node.value
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def shard_put(x: jax.Array, shardings, mesh: jax.sharding.Mesh) -> jax.Array:
|
|
189
|
+
# Single device sharding requires this special handling
|
|
190
|
+
# to avoid the recursive jit error.
|
|
191
|
+
if math.prod(mesh.axis_sizes) == 1:
|
|
192
|
+
return jax.device_put(x, mesh.devices.flatten()[0])
|
|
193
|
+
|
|
194
|
+
if isinstance(shardings, tuple):
|
|
195
|
+
return jax.device_put(x, NamedSharding(mesh, P(*shardings)))
|
|
196
|
+
else:
|
|
197
|
+
return jax.device_put(x, shardings)
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def get_default_maps(vllm_config, mesh: Mesh,
                     name_map: dict[str, str]) -> MetadataMap:
    """Build the default ``MetadataMap`` for attention/MLP style checkpoints.

    Shapes are derived from the model and HF configs; the sizes in the pad
    tables are ``mesh.shape["model"]`` integer-divided by the head count.

    Args:
        vllm_config: Config object exposing ``model_config`` (and through it
            ``hf_config``).
        mesh: Device mesh; only the size of its ``"model"`` axis is read.
        name_map: Checkpoint-key -> model-key mapping, stored as-is.
    """
    sharding_size = mesh.shape["model"]

    model_config = vllm_config.model_config
    hf_config = model_config.hf_config

    num_heads = hf_config.num_attention_heads
    num_kv_heads = hf_config.num_key_value_heads
    hidden_size = model_config.get_hidden_size()
    head_dim_original = model_config.get_head_size()

    # Weight shapes: q/k/v become (heads, head_dim, hidden); o_proj becomes
    # (hidden, heads, head_dim).
    kv_weight_shape = (num_kv_heads, head_dim_original, hidden_size)
    reshape_keys: dict[str, tuple[int, ...]] = dict(
        q_proj=(num_heads, head_dim_original, hidden_size),
        k_proj=kv_weight_shape,
        v_proj=kv_weight_shape,
        o_proj=(hidden_size, num_heads, head_dim_original),
    )

    kv_bias_shape = (num_kv_heads, head_dim_original)
    bias_reshape_keys: dict[str, tuple[int, ...]] = {
        "q_proj.bias": (num_heads, head_dim_original),
        "k_proj.bias": kv_bias_shape,
        "v_proj.bias": kv_bias_shape,
    }

    # Insertion order matters to consumers that stop at the first match.
    transpose_keys: dict[str, tuple[int, ...]] = {
        name: (1, 0)
        for name in ("lm_head", "fc", "gate_proj", "up_proj", "down_proj")
    }
    transpose_keys.update(
        {name: (2, 0, 1)
         for name in ("q_proj", "k_proj", "v_proj")})
    transpose_keys["o_proj"] = (1, 2, 0)

    # Extra transposes for the vision tower of multimodal models.
    if model_config.is_multimodal_model:
        # TODO: Wenlong: Do not consider padding for now
        transpose_keys["attn.proj"] = (1, 0)
        transpose_keys["attn.qkv"] = (1, 0)
        transpose_keys["visual.merger.mlp"] = (1, 0)
        transpose_keys["visual.patch_embed.proj"] = (2, 3, 4, 1, 0)

    # key: (padding_dim, padding_size)
    pad_keys: dict[str, tuple[int, ...]] = {
        "q_proj": (1, sharding_size // num_heads),
        "k_proj": (1, sharding_size // num_kv_heads),
        "v_proj": (1, sharding_size // num_kv_heads),
        "o_proj": (0, sharding_size // num_heads),
    }
    bias_pad_keys: dict[str, tuple[int, ...]] = {
        "q_proj.bias": (0, sharding_size // num_heads),
        "k_proj.bias": (0, sharding_size // num_kv_heads),
        "v_proj.bias": (0, sharding_size // num_kv_heads),
    }

    return MetadataMap(
        name_map=name_map,
        transpose_map=transpose_keys,
        reshape_map=reshape_keys,
        bias_reshape_map=bias_reshape_keys,
        pad_map=pad_keys,
        bias_pad_map=bias_pad_keys,
    )
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
def _matches_any_pattern(hf_key: str, patterns: list[str] | None) -> bool:
    """Return True when any regex in ``patterns`` matches the start of ``hf_key``."""
    if not patterns:
        return False
    return any(re.match(pattern, hf_key) for pattern in patterns)


def _resolve_model_key(hf_key: str, name_map: dict[str, str]) -> str | None:
    """Translate a checkpoint key into a model parameter path.

    Keys containing "layers" (checked first) or "blocks" are looked up in
    ``name_map`` with the numeric index replaced by "*"; the index is then put
    back into the mapped key. Other keys are looked up directly and fall back
    to themselves.

    Returns:
        The model key, or None when the tensor must be skipped.

    Raises:
        ValueError: If the key contains "layers"/"blocks" without an index.
        KeyError: If the wildcard form of a layer key is not in ``name_map``.
    """
    for container in ("layers", "blocks"):
        if container not in hf_key:
            continue
        index_match = re.search(rf"{container}\.(\d+)", hf_key)
        if index_match is None:
            # Previously surfaced as AttributeError on None.group(1).
            raise ValueError(
                f"{hf_key} contains '{container}' but no "
                f"'{container}.<index>' segment")
        layer_num = index_match.group(1)
        wildcard_key = re.sub(rf"{container}\.\d+", f"{container}.*", hf_key)
        model_key = name_map[wildcard_key]
        return re.sub(rf"{container}\.\*", f"{container}.{layer_num}",
                      model_key)

    if hf_key not in name_map:
        if hf_key == "lm_head":
            logger.warning(
                f"Skip loading {hf_key} due to tie_word_embeddings")
            return None
        if "t2d" in hf_key:
            logger.warning(
                f"Skip loading {hf_key} as it's not used in eagle-3 for now"
            )
            return None
    return name_map.get(hf_key, hf_key)


def _apply_layout_transforms(hf_key: str, hf_weight: jax.Array,
                             metadata_map: MetadataMap,
                             head_dim_pad: int) -> jax.Array:
    """Reshape, zero-pad head_dim and (weights only) transpose one tensor.

    In each table only the first entry whose key is a substring of ``hf_key``
    is applied.
    """
    if hf_key.endswith(".bias"):
        for key, new_shape in metadata_map.bias_reshape_map.items():
            if key in hf_key:
                hf_weight = jnp.reshape(hf_weight, new_shape)
                if head_dim_pad > 0:
                    hf_weight = jnp.pad(hf_weight,
                                        ((0, 0), (0, head_dim_pad)))
                break
        # Biases are never transposed.
        return hf_weight

    for key, new_shape in metadata_map.reshape_map.items():
        if key in hf_key:
            hf_weight = jnp.reshape(hf_weight, new_shape)
            if head_dim_pad > 0:
                # After the reshape head_dim is the last axis for o_proj and
                # the middle axis for the other projections.
                if "o_proj" in key:
                    pad_widths = ((0, 0), (0, 0), (0, head_dim_pad))
                else:
                    pad_widths = ((0, 0), (0, head_dim_pad), (0, 0))
                hf_weight = jnp.pad(hf_weight, pad_widths)
            break

    for key, permutation in metadata_map.transpose_map.items():
        if key in hf_key:
            hf_weight = jnp.transpose(hf_weight, permutation)
            break
    return hf_weight


def _replicate_heads(hf_key: str, hf_weight: jax.Array,
                     metadata_map: MetadataMap) -> jax.Array:
    """Repeat a tensor along the axis given by the matching pad-table entry.

    Entries whose repeat count is 0 are ignored and the search continues.
    """
    if hf_key.endswith(".bias"):
        table = metadata_map.bias_pad_map
    else:
        table = metadata_map.pad_map
    for key, value in table.items():
        axis, repeats = value[0], value[1]
        if key in hf_key and repeats != 0:
            return jnp.repeat(hf_weight, repeats, axis=axis)
    return hf_weight


def _load_hf_weights_on_thread(vllm_config,
                               params: nnx.State,
                               metadata_map: MetadataMap,
                               mesh: Mesh,
                               weights_file: str,
                               filter_regex: str | None = None,
                               keep_original_dtype_keys_regex: list[str]
                               | None = None):
    """Load every tensor of one weights file into ``params``.

    For each tensor: optionally cast to the configured dtype, resolve the
    model key, apply the reshape/pad/transpose/repeat rules of
    ``metadata_map``, then shard the result onto ``mesh`` and store it in the
    matching parameter.

    Args:
        vllm_config: Config object exposing ``model_config``.
        params: Model state that receives the loaded values (mutated).
        metadata_map: Name mapping and layout tables.
        mesh: Device mesh used for sharding.
        weights_file: Path of the safetensors file to read.
        filter_regex: Optional pattern restricting which tensors are read.
        keep_original_dtype_keys_regex: Patterns of checkpoint keys whose
            dtype must not be converted.
    """
    shard = functools.partial(shard_put, mesh=mesh)
    model_config = vllm_config.model_config

    # head_dim may be padded for kernel performance; this is the amount of
    # padding to add to each loaded tensor.
    head_dim_original = model_config.get_head_size()
    head_dim_pad = utils.get_padded_head_dim(
        head_dim_original) - head_dim_original

    try:
        shardings = nnx.get_named_sharding(params, mesh)
    except TypeError:
        # NOTE(review): presumably params without sharding annotations; the
        # params tree itself then stands in for the shardings tree - confirm.
        shardings = params

    # NOTE(review): the generator yields from inside its
    # jax.default_device(cpu) context, so the transforms below presumably run
    # on CPU - confirm.
    for hf_key, hf_weight in model_weights_single_file_generator(
            weights_file, framework="flax", filter_regex=filter_regex):

        # The keep-dtype patterns are matched against the raw checkpoint key.
        keep_original_dtype = _matches_any_pattern(
            hf_key, keep_original_dtype_keys_regex)
        if not keep_original_dtype and hf_weight.dtype != model_config.dtype:
            logger.warning(
                f"Converting dtype for {hf_key} from {hf_weight.dtype} to {model_config.dtype}"
            )
            hf_weight = hf_weight.astype(model_config.dtype)

        hf_key = hf_key.removesuffix(".weight")

        model_key = _resolve_model_key(hf_key, metadata_map.name_map)
        if model_key is None:
            continue
        model_weight, model_sharding = get_param_and_sharding(
            params, shardings, model_key)

        logger.debug(
            "before transform | "
            f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
        )

        hf_weight = _apply_layout_transforms(hf_key, hf_weight, metadata_map,
                                             head_dim_pad)
        hf_weight = _replicate_heads(hf_key, hf_weight, metadata_map)

        logger.debug(
            "after transform | "
            f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
        )

        # Shapes are only compared when no head_dim padding is in effect.
        if head_dim_pad == 0:
            assert model_weight.value.shape == hf_weight.shape, f"{hf_key}: {model_weight.value.shape} != {hf_weight.shape}"

        # Update the model weight
        spec = model_weight.sharding.spec if isinstance(
            model_weight.sharding, NamedSharding) else model_weight.sharding
        model_weight.value = shard(hf_weight, spec)
|
|
403
|
+
|
|
404
|
+
|
|
405
|
+
def load_hf_weights(vllm_config,
                    model: nnx.Module,
                    metadata_map: MetadataMap,
                    mesh: Mesh,
                    filter_regex: str | None = None,
                    is_draft_model: bool = False,
                    keep_original_dtype_keys_regex: list[str] | None = None):
    """Load all weights files of a model into ``model`` using a thread pool.

    Args:
        vllm_config: Config object; supplies the model path and download dir.
        model: Module whose state is filled and updated in place.
        metadata_map: Name mapping and layout tables for the loader.
        mesh: Device mesh used for sharding.
        filter_regex: Optional pattern restricting which tensors are read.
        is_draft_model: Load the speculative draft model's weights instead of
            the main model's.
        keep_original_dtype_keys_regex: Patterns of checkpoint keys whose
            dtype must not be converted.
    """
    if is_draft_model:
        source_config = vllm_config.speculative_config.draft_model_config
    else:
        source_config = vllm_config.model_config
    weights_files = get_model_weights_files(
        source_config.model, vllm_config.load_config.download_dir)
    params = nnx.state(model)

    # NOTE(xiang): Multi-threading is disabled when running multi-host,
    # because it would cause different JAX processes to load different
    # weights at the same time.
    on_ray_multihost = os.environ.get("TPU_MULTIHOST_BACKEND",
                                      "").lower() == "ray"
    max_workers = 1 if on_ray_multihost else min(64, len(weights_files))

    load_one_file = functools.partial(
        _load_hf_weights_on_thread,
        vllm_config,
        params,
        metadata_map,
        mesh,
        filter_regex=filter_regex,
        keep_original_dtype_keys_regex=keep_original_dtype_keys_regex,
    )
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        pending = [
            executor.submit(load_one_file, weights_file)
            for weights_file in weights_files
        ]
        # result() re-raises any exception from the worker.
        for task in pending:
            task.result()

    check_all_loaded(params)
    nnx.update(model, params)
|
|
443
|
+
|
|
444
|
+
|
|
445
|
+
def check_all_loaded(params: nnx.State):
    """Raise if any visited ``nnx.Param`` still holds a ``ShapeDtypeStruct``.

    Raises:
        ValueError: For the first such parameter encountered by the tree map.
    """

    def _raise_if_placeholder(leaf: Any):
        if not isinstance(leaf, nnx.Param):
            return
        if isinstance(leaf.value, jax.ShapeDtypeStruct):
            raise ValueError(f"The param does not load weights: {leaf}")

    jax.tree.map(_raise_if_placeholder, params)
|
|
453
|
+
|
|
454
|
+
|
|
455
|
+
def build_flat_dict(flat_state, mappings):
    """Re-key a flat state by source names using ``mappings``.

    Args:
        flat_state: Iterable of ``(keys, value)`` pairs; ``keys`` is joined
            with "." to form the target path.
        mappings: ``{src: (tgt, sharding)}``. A "*" segment in ``tgt``
            matches a numeric path segment, and the captured numbers fill the
            "*" segments of ``src`` from left to right.

    Returns:
        ``{actual_src: (value, sharding)}`` for every entry whose path matches
        a mapping target (first matching mapping wins). Entries without a
        match are logged and omitted.
    """
    # The pattern and the split source depend only on the mapping, so build
    # them once instead of once per (state entry, mapping) pair.
    compiled_mappings = [
        (
            re.compile("^" + re.escape(tgt).replace("\\.\\*", r"\.(\d+)") +
                       "$"),
            src.split("."),
            sharding,
        )
        for src, (tgt, sharding) in mappings.items()
    ]

    new_flat_dict = {}
    for keys, v in flat_state:
        path = '.'.join(str(key) for key in keys)
        for pattern, src_template, sharding in compiled_mappings:
            matched = pattern.match(path)
            if not matched:
                continue
            # Substitute captured indices for the "*" segments, in order.
            wildcards = matched.groups()
            src_parts = []
            wc_index = 0
            for part in src_template:
                if part == "*":
                    src_parts.append(wildcards[wc_index])
                    wc_index += 1
                else:
                    src_parts.append(part)
            new_flat_dict[".".join(src_parts)] = v, sharding
            break
        else:
            logger.info(f"!!! No mapping for flat state: {keys}")
    return new_flat_dict
|
|
482
|
+
|
|
483
|
+
|
|
484
|
+
def transfer_state_with_mappings(src_state,
                                 tgt_state,
                                 mappings,
                                 transpose_keys=None,
                                 shard=None):
    """Copy values from ``src_state`` into ``tgt_state`` following ``mappings``.

    Args:
        src_state: State providing the values.
        tgt_state: State receiving the values.
        mappings: ``{src: (tgt, sharding)}`` as accepted by
            ``build_flat_dict``.
        transpose_keys: Optional ``{last_key: permutation}``; applied when the
            last source key is listed and does not contain "lora".
        shard: Optional callable ``shard(value, sharding)`` applied before the
            value is stored.

    Returns:
        The target state rebuilt from its updated flat representation.
    """
    src_flat = src_state.flat_state()
    tgt_flat = tgt_state.flat_state()

    targets_by_src_key = build_flat_dict(tgt_flat, mappings)
    logger.info(f"{mappings=}")
    logger.info(f"{transpose_keys=}")

    for src_keys, src_var in src_flat:
        flattened_src_keys = '.'.join(str(k) for k in src_keys)
        new_v = jnp.copy(src_var.value)
        logger.info(
            f"Processing source key: {flattened_src_keys} and value: {new_v.shape} {new_v.dtype}"
        )

        target_entry = targets_by_src_key.get(flattened_src_keys)
        if target_entry is None:
            logger.info(f"!!! No mapping for source key: {flattened_src_keys}")
            continue
        target_param, sharding = target_entry[0], target_entry[1]

        # E.g. layers.*.attn.k_proj.w, layers.*.attn.k_proj.w_lora_a
        # E.g. layers.*.mlp.down_proj.kernel, layers.*.mlp.down_proj.kernel_lora_a
        v_maybe_t = new_v
        if transpose_keys is not None:
            last_key = src_keys[-1]
            if last_key in transpose_keys and 'lora' not in last_key:
                v_maybe_t = jnp.transpose(new_v, transpose_keys[last_key])

        to_update_value = target_param.value
        assert to_update_value.shape == v_maybe_t.shape, \
            f"Shape mismatch for {flattened_src_keys}: {to_update_value.shape} vs {v_maybe_t.shape}"

        # The target's dtype wins when the two sides disagree.
        if to_update_value.dtype != v_maybe_t.dtype:
            logger.info(
                f"Type mismatch between external model and vLLM model. Converting {v_maybe_t.dtype=} to {to_update_value.dtype=}"
            )
            v_maybe_t = v_maybe_t.astype(to_update_value.dtype)

        if shard:
            target_param.value = shard(v_maybe_t, sharding)
        else:
            target_param.value = v_maybe_t

    return tgt_state.from_flat_path(tgt_flat)
|
|
File without changes
|