tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
@@ -0,0 +1,621 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Utilities for downloading model weights from HuggingFace."""
+
+ import functools
+ import glob
+ import math
+ import os
+ import re
+ from collections.abc import Generator
+ from concurrent.futures import ThreadPoolExecutor
+ from dataclasses import dataclass, field
+ from typing import Any, Optional
+
+ import jax
+ import jax.numpy as jnp
+ import torch
+ import torchax
+ from flax import nnx
+ from jax.sharding import Mesh, NamedSharding
+ from jax.sharding import PartitionSpec as P
+ from safetensors import safe_open
+ from vllm.config import VllmConfig
+
+ from tpu_inference import envs, utils
+ from tpu_inference.logger import init_logger
+ from tpu_inference.models.jax.utils import file_utils
+
+ logger = init_logger(__name__)
+
+ HF_WEIGHTS_FORMAT = "*.safetensors"
+
+ DTYPE_VIEW_MAP = {
+     jnp.dtype(jnp.float8_e4m3fn): torch.uint8,
+     jnp.dtype(jnp.bfloat16): torch.uint16,
+     jnp.dtype(jnp.float32): torch.uint32,
+ }
+
+
+ @dataclass
+ class MetadataMap:
+     name_map: dict[str, str] = field(default_factory=dict)
+     transpose_map: dict[str, tuple[int, ...]] = field(default_factory=dict)
+     reshape_map: dict[str, tuple[int, ...]] = field(default_factory=dict)
+     bias_reshape_map: dict[str, tuple[int, ...]] = field(default_factory=dict)
+     pad_map: dict[str, tuple[int, ...]] = field(default_factory=dict)
+     bias_pad_map: dict[str, tuple[int, ...]] = field(default_factory=dict)
+
+
+ ############ START Used by llama4, deepseek only for now START ############
+
+
+ def print_param_info(param: nnx.Param, name: str):
+     logger.warning(f"Global shape for {name}: {param.value.shape}")
+     logger.warning(f"Sharding for {name}: {param.sharding}")
+
+     logger.warning(
+         f"Shape of {name} on a single device: {param.value.addressable_shards[0].data.shape}"
+     )
+
+
+ def transpose_params(param_key: str, param_tensor: jax.Array, transpose_map):
+     for key, value in transpose_map.items():
+         if key in param_key:
+             return jnp.transpose(param_tensor, value)
+     return param_tensor  # Base case / no-op
+
+
+ def reshape_params(param_key: str, param_tensor: jax.Array, shape_map):
+     for key, new_shape in shape_map.items():
+         if key in param_key:
+             try:
+                 # TODO(gpolovets): Add validation on whether reshape preserves data layout.
+                 return jnp.reshape(param_tensor, new_shape)
+             except TypeError:
+                 raise TypeError(
+                     f"Cannot reshape for key={key}, new_shape={new_shape}, param_shape={param_tensor.shape}"
+                 )
+     return param_tensor  # Base case / no-op
+
+
+ def model_file_generator(
+         model_name_or_path: str,
+         download_dir: Optional[str]) -> Generator[str, None, None]:
+     weights_files = get_model_weights_files(model_name_or_path, download_dir)
+     for st_file in weights_files:
+         yield st_file
+
+
+ def model_weights_generator(
+     model_name_or_path: str,
+     framework: str,
+     filter_regex: Optional[str] = None,
+     download_dir: Optional[str] = None,
+ ) -> Generator[tuple, None, None]:
+     for st_file in model_file_generator(model_name_or_path, download_dir):
+         for name, weight_tensor in model_weights_single_file_generator(
+                 st_file, framework, filter_regex):
+             yield name, weight_tensor
+
+
+ def convert_torch_to_jax_with_view(loaded_weight: torch.Tensor,
+                                    cast_type: jnp.dtype) -> jax.Array:
+     """
+     Converts a PyTorch tensor to a JAX array by reinterpreting its
+     bit representation using a dtype view map.
+     """
+     torch_view_type = DTYPE_VIEW_MAP.get(jnp.dtype(cast_type))
+     loaded_weight = jnp.array(
+         loaded_weight.view(torch_view_type).numpy()).view(cast_type)
+     return loaded_weight
+
+
+ ############ END Used by llama4, deepseek only for now END ############
+
+
+ def get_model_weights_files(
+         model_name_or_path: str,
+         download_dir: Optional[str]) -> list[str]:
+     """
+     Helper to get the model's weight files, from a local directory or the HF Hub.
+     """
+
+     if os.path.isdir(model_name_or_path):
+         logger.info(f"Found weights from local: {model_name_or_path}")
+         weights_files = glob.glob(
+             os.path.join(model_name_or_path, HF_WEIGHTS_FORMAT))
+     elif file_utils.is_hf_repo(model_name_or_path):
+         logger.info(f"Downloading weights from HF {model_name_or_path}")
+         weights_files = file_utils.download_model_weights_from_hf(
+             model_name_or_path, download_dir, HF_WEIGHTS_FORMAT)
+     else:
+         raise ValueError(
+             f"{model_name_or_path} must be a local directory, or a Huggingface model id."
+         )
+
+     if not weights_files:
+         raise RuntimeError(
+             f"Cannot find any {HF_WEIGHTS_FORMAT} files in {model_name_or_path}."
+         )
+
+     weights_files.sort()
+     return weights_files
+
+
+ def model_weights_single_file_generator(
+     weights_file: str,
+     framework: str,
+     filter_regex: Optional[str] = None,
+ ) -> Generator[tuple, None, None]:
+     logger.info(f"Loading weights from {weights_file}")
+     # NOTE: We enforce loading tensors on CPU here. Otherwise the tensor would
+     # be loaded on TPU:0 by default; even though it would eventually be sharded
+     # across multiple TPUs, this would lead to OOM on TPU:0 for large models.
+     with jax.default_device(jax.devices("cpu")[0]):
+         with safe_open(weights_file, framework=framework) as f:
+             for name in f.keys():
+                 if filter_regex is not None and not re.match(
+                         filter_regex, name):
+                     continue
+                 weight_tensor = f.get_tensor(name)
+                 yield name, weight_tensor
+
+
+ def get_param(params: nnx.State, path: str) -> nnx.State:
+     keys = path.split(".")
+     plevel = params
+     for key in keys:
+         if key.isdigit():
+             plevel = plevel[int(key)]
+         else:
+             if key in plevel:
+                 plevel = plevel[key]
+             else:
+                 raise ValueError(f"{path} is not a valid param path")
+     return plevel
+
+
+ def get_param_and_sharding(params: nnx.State, shardings: Any,
+                            path: str) -> tuple[nnx.State, nnx.State]:
+     keys = path.split(".")
+     plevel = params
+     slevel = shardings
+     for key in keys:
+         if key.isdigit():
+             plevel = plevel[int(key)]
+             slevel = slevel[int(key)]
+         else:
+             if key in plevel:
+                 plevel = plevel[key]
+                 slevel = slevel[key]
+             else:
+                 raise ValueError(f"{path} is not a valid param path")
+     return plevel, slevel.value
+
+
+ def shard_put(x: jax.Array, shardings, mesh: jax.sharding.Mesh) -> jax.Array:
+     # Single device sharding requires this special handling
+     # to avoid the recursive jit error.
+     if math.prod(mesh.axis_sizes) == 1:
+         return jax.device_put(x, mesh.devices.flatten()[0])
+
+     if isinstance(shardings, tuple):
+         return jax.device_put(x, NamedSharding(mesh, P(*shardings)))
+     else:
+         return jax.device_put(x, shardings)
+
+
+ def get_default_maps(model_config, mesh: Mesh,
+                      name_map: dict[str, str]) -> MetadataMap:
+     """Build the default reshape/transpose/pad metadata maps used to adapt HF weights."""
+     sharding_size = mesh.shape["model"]
+
+     hf_config = model_config.hf_config
+
+     num_heads = hf_config.num_attention_heads
+     num_kv_heads = hf_config.num_key_value_heads
+     hidden_size = model_config.get_hidden_size()
+
+     # Pad head_dim for kernel performance.
+     head_dim_original = model_config.get_head_size()
+
+     reshape_keys: dict[str, tuple[int, ...]] = {
+         "q_proj": (num_heads, head_dim_original, hidden_size),
+         "k_proj": (num_kv_heads, head_dim_original, hidden_size),
+         "v_proj": (num_kv_heads, head_dim_original, hidden_size),
+         "o_proj": (hidden_size, num_heads, head_dim_original),
+     }
+     bias_reshape_keys: dict[str, tuple[int, ...]] = {
+         "q_proj.bias": (num_heads, head_dim_original),
+         "k_proj.bias": (num_kv_heads, head_dim_original),
+         "v_proj.bias": (num_kv_heads, head_dim_original)
+     }
+     transpose_keys: dict[str, tuple[int, ...]] = {
+         "lm_head": (1, 0),
+         "fc": (1, 0),
+         "gate_proj": (1, 0),
+         "up_proj": (1, 0),
+         "down_proj": (1, 0),
+         "q_proj": (2, 0, 1),
+         "k_proj": (2, 0, 1),
+         "v_proj": (2, 0, 1),
+         "o_proj": (1, 2, 0),
+     }
+
+     # Get the vision config.
+     if model_config.is_multimodal_model:
+         # TODO(Wenlong): Do not consider padding for now.
+         transpose_keys.update({
+             "attn.proj": (1, 0),
+             "attn.qkv": (1, 0),
+             "visual.merger.mlp": (1, 0),
+             "visual.patch_embed.proj": (2, 3, 4, 1, 0),
+         })
+
+     # key: (padding_dim, padding_size)
+     pad_keys: dict[str, tuple[int, ...]] = {
+         "q_proj": (1, sharding_size // num_heads),
+         "k_proj": (1, sharding_size // num_kv_heads),
+         "v_proj": (1, sharding_size // num_kv_heads),
+         "o_proj": (0, sharding_size // num_heads),
+     }
+     bias_pad_keys: dict[str, tuple[int, ...]] = {
+         "q_proj.bias": (0, sharding_size // num_heads),
+         "k_proj.bias": (0, sharding_size // num_kv_heads),
+         "v_proj.bias": (0, sharding_size // num_kv_heads),
+     }
+
+     return MetadataMap(name_map=name_map,
+                        reshape_map=reshape_keys,
+                        bias_reshape_map=bias_reshape_keys,
+                        transpose_map=transpose_keys,
+                        pad_map=pad_keys,
+                        bias_pad_map=bias_pad_keys)
+
+
+ def _load_and_shard_weight(vllm_config,
+                            params: nnx.State,
+                            shardings: Any,
+                            metadata_map: MetadataMap,
+                            mesh: Mesh,
+                            hf_key: str,
+                            hf_weight: jax.Array,
+                            keep_original_dtype_keys_regex: list[str]
+                            | None = None,
+                            pp_missing_layers: list[str] | None = None):
+     name_map = metadata_map.name_map
+     reshape_keys = metadata_map.reshape_map
+     bias_reshape_keys = metadata_map.bias_reshape_map
+     transpose_keys = metadata_map.transpose_map
+     pad_keys = metadata_map.pad_map
+     bias_pad_keys = metadata_map.bias_pad_map
+
+     shard = functools.partial(shard_put, mesh=mesh)
+
+     model_config = vllm_config.model_config
+
+     # Pad head_dim for kernel performance.
+     head_dim_original = model_config.get_head_size()
+     head_dim = utils.get_padded_head_dim(head_dim_original)
+     head_dim_pad = head_dim - head_dim_original
+
+     # Check if the key should retain its original dtype
+     keep_original_dtype = False
+     if keep_original_dtype_keys_regex:
+         for pattern in keep_original_dtype_keys_regex:
+             if re.match(pattern, hf_key):
+                 keep_original_dtype = True
+                 break
+
+     # Converting to config's dtype
+     if not keep_original_dtype and hf_weight.dtype != model_config.dtype:
+         logger.warning(
+             f"Converting dtype for {hf_key} from {hf_weight.dtype} to {model_config.dtype}"
+         )
+         hf_weight = hf_weight.astype(model_config.dtype)
+
+     if hf_key.endswith(".weight"):
+         hf_key = hf_key.removesuffix(".weight")
+
+     # Find the corresponding model key using the HF key
+     if "layers" in hf_key:
+         layer_num = re.search(r"layers\.(\d+)", hf_key).group(1)
+         layer_key = re.sub(r"layers\.\d+", "layers.*", hf_key)
+         model_key = name_map[layer_key]
+         model_key = re.sub(r"layers\.\*", f"layers.{layer_num}", model_key)
+     elif "blocks" in hf_key:
+         layer_num = re.search(r"blocks\.(\d+)", hf_key).group(1)
+         layer_key = re.sub(r"blocks\.\d+", "blocks.*", hf_key)
+         model_key = name_map[layer_key]
+         model_key = re.sub(r"blocks\.\*", f"blocks.{layer_num}", model_key)
+     else:
+         if hf_key not in name_map and hf_key == "lm_head":
+             logger.warning(f"Skip loading {hf_key} due to tie_word_embeddings")
+             return
+         if hf_key not in name_map and "t2d" in hf_key:
+             logger.warning(
+                 f"Skip loading {hf_key} as it's not used in eagle-3 for now")
+             return
+         model_key = name_map.get(hf_key, hf_key)
+
+     if pp_missing_layers and _is_pp_missing_layer(hf_key, pp_missing_layers):
+         logger.warning(
+             f"Skip loading {hf_key} as it doesn't belong to this PP stage.")
+         return
+     model_weight, model_sharding = get_param_and_sharding(
+         params, shardings, model_key)
+
+     logger.debug(
+         "before transform | "
+         f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
+     )
+
+     if hf_key.endswith(".bias"):
+         for key in bias_reshape_keys:
+             if key in hf_key:
+                 hf_weight = jnp.reshape(hf_weight, bias_reshape_keys[key])
+                 if head_dim_pad > 0:
+                     hf_weight = jnp.pad(hf_weight, ((0, 0), (0, head_dim_pad)))
+                 break
+     else:
+         for key in reshape_keys:
+             if key in hf_key:
+                 hf_weight = jnp.reshape(hf_weight, reshape_keys[key])
+                 if head_dim_pad > 0:
+                     if "o_proj" in key:
+                         hf_weight = jnp.pad(hf_weight, ((0, 0), (0, 0),
+                                                         (0, head_dim_pad)))
+                     else:
+                         hf_weight = jnp.pad(hf_weight,
+                                             ((0, 0), (0, head_dim_pad),
+                                              (0, 0)))
+                 break
+     for key in transpose_keys:
+         if key in hf_key:
+             hf_weight = jnp.transpose(hf_weight, transpose_keys[key])
+             break
+
+     # Pad num-kv-heads
+     if hf_key.endswith(".bias"):
+         for key, value in bias_pad_keys.items():
+             dim = value[0]
+             dim_size = value[1]
+             if key in hf_key and dim_size != 0:
+                 hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
+                 break
+     else:
+         for key, value in pad_keys.items():
+             dim = value[0]
+             dim_size = value[1]
+             if key in hf_key and dim_size != 0:
+                 hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
+                 break
+
+     logger.debug(
+         "after transform | "
+         f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
+     )
+
+     if head_dim_pad == 0:
+         assert model_weight.value.shape == hf_weight.shape, f"{hf_key}: {model_weight.value.shape} != {hf_weight.shape}"
+
+     # Update the model weight
+     spec = model_weight.sharding.spec if isinstance(
+         model_weight.sharding, NamedSharding) else model_weight.sharding
+     model_weight.value = shard(hf_weight, spec)
+
+
+ def _is_pp_missing_layer(hf_key: str, pp_missing_layers: list[str]) -> bool:
+     has_digit = any(char.isdigit() for char in hf_key)
+     # Append a "." after the layer number so that e.g. "layers.1" does not
+     # also match "layers.10".
+     suffix = "." if has_digit else ""
+     return any(f'{pp_missing_layer}{suffix}' in hf_key
+                for pp_missing_layer in pp_missing_layers)
+
+
+ def _load_hf_weights_on_thread(
+     vllm_config: VllmConfig,
+     params: nnx.State,
+     metadata_map: "MetadataMap",
+     mesh: Mesh,
+     weights_file: str,
+     filter_regex: Optional[str] = None,
+     keep_original_dtype_keys_regex: Optional[list[str]] = None,
+     pp_missing_layers: list[str] | None = None,
+ ):
+     """Loads weights from a single weights file."""
+     try:
+         shardings = nnx.get_named_sharding(params, mesh)
+     except TypeError:
+         shardings = params
+
+     for hf_key, hf_weight in model_weights_single_file_generator(
+             weights_file, framework="flax", filter_regex=filter_regex):
+         _load_and_shard_weight(
+             vllm_config,
+             params,
+             shardings,
+             metadata_map,
+             mesh,
+             hf_key,
+             hf_weight,
+             keep_original_dtype_keys_regex,
+             pp_missing_layers,
+         )
+
+
+ def load_hf_weights(
+     vllm_config: VllmConfig,
+     model: nnx.Module,
+     metadata_map: "MetadataMap",
+     mesh: Mesh,
+     filter_regex: Optional[str] = None,
+     is_draft_model: bool = False,
+     keep_original_dtype_keys_regex: Optional[list[str]] = None,
+     pp_missing_layers: list[str] | None = None,
+ ):
+     """Load weights into a JAX model from either an iterator or files."""
+     params = nnx.state(model)
+     try:
+         shardings = nnx.get_named_sharding(params, mesh)
+     except TypeError:
+         shardings = params
+     weights_iterator = None
+     if hasattr(vllm_config.model_config, "model_weights_iterator"):
+         weights_iterator = vllm_config.model_config.model_weights_iterator
+     env = torchax.default_env()
+     # The weights_iterator is used in the RunAI model streamer integration.
+     if weights_iterator is not None:
+         for hf_key, hf_weight in weights_iterator:
+             if filter_regex and not re.match(filter_regex, hf_key):
+                 continue
+
+             # Since the weights_iterator yields PyTorch tensors (torch.Tensor),
+             # we need to convert them to JAX arrays (jax.Array).
+             hf_weight_jax = env.t2j_copy(hf_weight)
+
+             _load_and_shard_weight(
+                 vllm_config,
+                 params,
+                 shardings,
+                 metadata_map,
+                 mesh,
+                 hf_key,
+                 hf_weight_jax,
+                 keep_original_dtype_keys_regex,
+                 pp_missing_layers=pp_missing_layers,
+             )
+     else:
+         # File-based path (multi-threaded).
+         if is_draft_model:
+             model_path = vllm_config.speculative_config.draft_model_config.model
+         else:
+             model_path = vllm_config.model_config.model
+         weights_files = get_model_weights_files(
+             model_path, vllm_config.load_config.download_dir)
+         max_workers = min(64, len(weights_files))
+         # NOTE(xiang): Disable multi-threading when running on multi-host,
+         # because multi-threading would cause different JAX processes to load
+         # different weights at the same time.
+         if envs.TPU_MULTIHOST_BACKEND == "ray":
+             max_workers = 1
+         with ThreadPoolExecutor(max_workers=max_workers) as executor:
+             futures = [
+                 executor.submit(
+                     _load_hf_weights_on_thread,
+                     vllm_config,
+                     params,
+                     metadata_map,
+                     mesh,
+                     weights_file,
+                     filter_regex=filter_regex,
+                     keep_original_dtype_keys_regex=
+                     keep_original_dtype_keys_regex,
+                     pp_missing_layers=pp_missing_layers,
+                 ) for weights_file in weights_files
+             ]
+             for future in futures:
+                 future.result()
+
+     check_all_loaded(params)
+     nnx.update(model, params)
+
+
+ def check_all_loaded(params: nnx.State):
+
+     def _check(x: Any):
+         if isinstance(x, nnx.Param) and isinstance(x.value,
+                                                    jax.ShapeDtypeStruct):
+             raise ValueError(f"The param does not load weights: {x}")
+
+     jax.tree.map(_check, params)
+
+
+ def build_flat_dict(flat_state, mappings):
+     """Build a new flat dictionary from the flat state using the provided mappings."""
+     new_flat_dict = {}
+     for keys, v in flat_state:
+         path = '.'.join(str(key) for key in keys)
+         mapped = False
+         for src, (tgt, sharding) in mappings.items():
+             regex = "^" + re.escape(tgt).replace("\\.\\*", r"\.(\d+)") + "$"
+             matched = re.match(regex, path)
+             if matched:
+                 # Extract wildcards if any
+                 wildcards = matched.groups()
+                 src_parts = []
+                 wc_index = 0
+                 for part in src.split("."):
+                     if part == "*":
+                         src_parts.append(wildcards[wc_index])
+                         wc_index += 1
+                     else:
+                         src_parts.append(part)
+                 actual_src = ".".join(src_parts)
+                 new_flat_dict[actual_src] = v, sharding
+                 mapped = True
+                 break
+         if not mapped:
+             logger.info(f"!!! No mapping for flat state: {keys}")
+     return new_flat_dict
+
+
+ def transfer_state_with_mappings(src_state,
+                                  tgt_state,
+                                  mappings,
+                                  transpose_keys=None,
+                                  shard=None):
+     """Transfer state from src_state to tgt_state using the provided mappings."""
+     src_flat = src_state.flat_state()
+     tgt_flat = tgt_state.flat_state()
+
+     new_src_dict = build_flat_dict(tgt_flat, mappings)
+     logger.info(f"{mappings=}")
+     logger.info(f"{transpose_keys=}")
+     for src_keys, v in src_flat:
+         flattened_src_keys = '.'.join(str(k) for k in src_keys)
+         new_v = jnp.copy(v.value)
+         logger.info(
+             f"Processing source key: {flattened_src_keys} and value: {new_v.shape} {new_v.dtype}"
+         )
+         if flattened_src_keys not in new_src_dict:
+             logger.info(f"!!! No mapping for source key: {flattened_src_keys}")
+             continue
+         sharding = new_src_dict[flattened_src_keys][1]
+
+         # E.g. layers.*.attn.k_proj.w, layers.*.attn.k_proj.w_lora_a
+         # E.g. layers.*.mlp.down_proj.kernel, layers.*.mlp.down_proj.kernel_lora_a
+         if transpose_keys is not None \
+                 and ((src_keys[-1] in transpose_keys) and ('lora' not in src_keys[-1])):
+             v_maybe_t = jnp.transpose(new_v, transpose_keys[src_keys[-1]])
+         else:
+             v_maybe_t = new_v
+
+         to_update_value = new_src_dict[flattened_src_keys][0].value
+         assert to_update_value.shape == v_maybe_t.shape, \
+             f"Shape mismatch for {flattened_src_keys}: {to_update_value.shape} vs {v_maybe_t.shape}"
+
+         if to_update_value.dtype != v_maybe_t.dtype:
+             logger.info(
+                 f"Type mismatch between external model and vLLM model. Converting {v_maybe_t.dtype=} to {to_update_value.dtype=}"
+             )
+             v_maybe_t = v_maybe_t.astype(to_update_value.dtype)
+
+         new_src_dict[flattened_src_keys][0].value = shard(
+             v_maybe_t, sharding) if shard else v_maybe_t
+
+     tgt_state = tgt_state.from_flat_path(tgt_flat)
+     return tgt_state
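
For orientation, the file-based loading path in this module composes roughly as shown below. This is a minimal usage sketch, not code from the package: the load_model wrapper and the example name_map entry are assumptions, while get_default_maps and load_hf_weights are the functions defined in the file above.

    # Hypothetical wiring; only get_default_maps/load_hf_weights come from this file.
    def load_model(vllm_config, model, mesh):
        # Model-specific map from HF checkpoint names to model parameter paths
        # (illustrative entry only).
        name_map = {"model.embed_tokens": "embed.embedding"}
        metadata_map = get_default_maps(vllm_config.model_config, mesh, name_map)
        # Reads *.safetensors shards (multi-threaded), applies the reshape/
        # transpose/pad maps, and device_puts each tensor per its sharding.
        load_hf_weights(vllm_config, model, metadata_map, mesh)
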
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.