tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. Click here for more details.

Files changed (257)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +317 -34
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +406 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +320 -0
  64. tests/layers/vllm/test_unquantized.py +662 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +26 -6
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +25 -4
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +25 -12
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  154. tpu_inference/layers/common/quant_methods.py +15 -0
  155. tpu_inference/layers/common/quantization.py +282 -0
  156. tpu_inference/layers/common/sharding.py +32 -9
  157. tpu_inference/layers/common/utils.py +94 -0
  158. tpu_inference/layers/jax/__init__.py +13 -0
  159. tpu_inference/layers/jax/attention/__init__.py +13 -0
  160. tpu_inference/layers/jax/attention/attention.py +19 -6
  161. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  162. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  163. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  164. tpu_inference/layers/jax/base.py +14 -0
  165. tpu_inference/layers/jax/constants.py +13 -0
  166. tpu_inference/layers/jax/layers.py +14 -0
  167. tpu_inference/layers/jax/misc.py +14 -0
  168. tpu_inference/layers/jax/moe/__init__.py +13 -0
  169. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  170. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  171. tpu_inference/layers/jax/moe/moe.py +43 -3
  172. tpu_inference/layers/jax/pp_utils.py +53 -0
  173. tpu_inference/layers/jax/rope.py +14 -0
  174. tpu_inference/layers/jax/rope_interface.py +14 -0
  175. tpu_inference/layers/jax/sample/__init__.py +13 -0
  176. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  177. tpu_inference/layers/jax/sample/sampling.py +15 -1
  178. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  179. tpu_inference/layers/jax/transformer_block.py +14 -0
  180. tpu_inference/layers/vllm/__init__.py +13 -0
  181. tpu_inference/layers/vllm/attention.py +4 -4
  182. tpu_inference/layers/vllm/fused_moe.py +101 -494
  183. tpu_inference/layers/vllm/linear.py +64 -0
  184. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  185. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  186. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  187. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  188. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  189. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  191. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
  192. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
  193. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  194. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  195. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  196. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
  197. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  198. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
  199. tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
  200. tpu_inference/lora/__init__.py +13 -0
  201. tpu_inference/lora/torch_lora_ops.py +8 -13
  202. tpu_inference/models/__init__.py +13 -0
  203. tpu_inference/models/common/__init__.py +13 -0
  204. tpu_inference/models/common/model_loader.py +112 -35
  205. tpu_inference/models/jax/__init__.py +13 -0
  206. tpu_inference/models/jax/deepseek_v3.py +267 -157
  207. tpu_inference/models/jax/gpt_oss.py +26 -10
  208. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  209. tpu_inference/models/jax/llama3.py +99 -36
  210. tpu_inference/models/jax/llama4.py +14 -0
  211. tpu_inference/models/jax/llama_eagle3.py +18 -5
  212. tpu_inference/models/jax/llama_guard_4.py +15 -1
  213. tpu_inference/models/jax/qwen2.py +17 -2
  214. tpu_inference/models/jax/qwen2_5_vl.py +179 -51
  215. tpu_inference/models/jax/qwen3.py +17 -2
  216. tpu_inference/models/jax/utils/__init__.py +13 -0
  217. tpu_inference/models/jax/utils/file_utils.py +14 -0
  218. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  219. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  220. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
  221. tpu_inference/models/jax/utils/weight_utils.py +234 -155
  222. tpu_inference/models/vllm/__init__.py +13 -0
  223. tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
  224. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  225. tpu_inference/platforms/__init__.py +14 -0
  226. tpu_inference/platforms/tpu_platform.py +51 -72
  227. tpu_inference/runner/__init__.py +13 -0
  228. tpu_inference/runner/compilation_manager.py +180 -80
  229. tpu_inference/runner/kv_cache.py +54 -20
  230. tpu_inference/runner/kv_cache_manager.py +55 -33
  231. tpu_inference/runner/lora_utils.py +16 -1
  232. tpu_inference/runner/multimodal_manager.py +16 -2
  233. tpu_inference/runner/persistent_batch_manager.py +54 -2
  234. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  235. tpu_inference/runner/structured_decoding_manager.py +16 -3
  236. tpu_inference/runner/tpu_runner.py +124 -61
  237. tpu_inference/runner/utils.py +2 -2
  238. tpu_inference/spec_decode/__init__.py +13 -0
  239. tpu_inference/spec_decode/jax/__init__.py +13 -0
  240. tpu_inference/spec_decode/jax/eagle3.py +84 -22
  241. tpu_inference/tpu_info.py +14 -0
  242. tpu_inference/utils.py +72 -44
  243. tpu_inference/worker/__init__.py +13 -0
  244. tpu_inference/worker/tpu_worker.py +66 -52
  245. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
  246. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  247. tpu_inference/layers/vllm/linear_common.py +0 -186
  248. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  249. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  250. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  251. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  252. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  253. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  254. tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
  255. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  256. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  257. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,16 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
1
14
  """Utilities for downloading model weights from HuggingFace."""
2
15
 
3
16
  import functools
@@ -13,10 +26,12 @@ from typing import Any, Optional
13
26
  import jax
14
27
  import jax.numpy as jnp
15
28
  import torch
29
+ import torchax
16
30
  from flax import nnx
17
31
  from jax.sharding import Mesh, NamedSharding
18
32
  from jax.sharding import PartitionSpec as P
19
33
  from safetensors import safe_open
34
+ from vllm.config import VllmConfig
20
35
 
21
36
  from tpu_inference import envs, utils
22
37
  from tpu_inference.logger import init_logger
@@ -65,7 +80,13 @@ def transpose_params(param_key: str, param_tensor: jax.Array, transpose_map):
65
80
  def reshape_params(param_key: str, param_tensor: jax.Array, shape_map):
66
81
  for key, new_shape in shape_map.items():
67
82
  if key in param_key:
68
- return jnp.reshape(param_tensor, new_shape)
83
+ try:
84
+ #TODO:(gpolovets) Add validation on whether reshape preserves data layout.
85
+ return jnp.reshape(param_tensor, new_shape)
86
+ except TypeError:
87
+ raise TypeError(
88
+ f"Cannot reshape for key={key}, new_shape={new_shape}, param_shape={param_tensor.shape}"
89
+ )
69
90
  return param_tensor # Base case / no-op
70
91
 
71
92
 
@@ -265,15 +286,16 @@ def get_default_maps(model_config, mesh: Mesh,
265
286
  bias_pad_map=bias_pad_keys)
266
287
 
267
288
 
268
- def _load_hf_weights_on_thread(vllm_config,
269
- params: nnx.State,
270
- metadata_map: MetadataMap,
271
- mesh: Mesh,
272
- weights_file: str,
273
- filter_regex: str | None = None,
274
- keep_original_dtype_keys_regex: list[str]
275
- | None = None,
276
- exclude_regex: list[str] | None = None):
289
+ def _load_and_shard_weight(vllm_config,
290
+ params: nnx.State,
291
+ shardings: Any,
292
+ metadata_map: MetadataMap,
293
+ mesh: Mesh,
294
+ hf_key: str,
295
+ hf_weight: jax.Array,
296
+ keep_original_dtype_keys_regex: list[str]
297
+ | None = None,
298
+ pp_missing_layers: list[str] | None = None):
277
299
  name_map = metadata_map.name_map
278
300
  reshape_keys = metadata_map.reshape_map
279
301
  bias_reshape_keys = metadata_map.bias_reshape_map
@@ -290,6 +312,131 @@ def _load_hf_weights_on_thread(vllm_config,
290
312
  head_dim = utils.get_padded_head_dim(head_dim_original)
291
313
  head_dim_pad = head_dim - head_dim_original
292
314
 
315
+ # Check if the key should retain its original dtype
316
+ keep_original_dtype = False
317
+ if keep_original_dtype_keys_regex:
318
+ for pattern in keep_original_dtype_keys_regex:
319
+ if re.match(pattern, hf_key):
320
+ keep_original_dtype = True
321
+ break
322
+
323
+ # Converting to config's dtype
324
+ if not keep_original_dtype and hf_weight.dtype != model_config.dtype:
325
+ logger.warning(
326
+ f"Converting dtype for {hf_key} from {hf_weight.dtype} to {model_config.dtype}"
327
+ )
328
+ hf_weight = hf_weight.astype(model_config.dtype)
329
+
330
+ if hf_key.endswith(".weight"):
331
+ hf_key = hf_key.removesuffix(".weight")
332
+
333
+ # Find the corresponding model key using the HF key
334
+ if "layers" in hf_key:
335
+ layer_num = re.search(r"layers\.(\d+)", hf_key).group(1)
336
+ layer_key = re.sub(r"layers\.\d+", "layers.*", hf_key)
337
+ model_key = name_map[layer_key]
338
+ model_key = re.sub(r"layers\.\*", f"layers.{layer_num}", model_key)
339
+ elif "blocks" in hf_key:
340
+ layer_num = re.search(r"blocks\.(\d+)", hf_key).group(1)
341
+ layer_key = re.sub(r"blocks\.\d+", "blocks.*", hf_key)
342
+ model_key = name_map[layer_key]
343
+ model_key = re.sub(r"blocks\.\*", f"blocks.{layer_num}", model_key)
344
+ else:
345
+ if hf_key not in name_map and hf_key == "lm_head":
346
+ logger.warning(f"Skip loading {hf_key} due to tie_word_embeddings")
347
+ return
348
+ if hf_key not in name_map and "t2d" in hf_key:
349
+ logger.warning(
350
+ f"Skip loading {hf_key} as it's not used in eagle-3 for now")
351
+ return
352
+ model_key = name_map.get(hf_key, hf_key)
353
+
354
+ if pp_missing_layers and _is_pp_missing_layer(hf_key, pp_missing_layers):
355
+ logger.warning(
356
+ f"Skip loading {hf_key} as it doesn't belong to this PP stage.")
357
+ return
358
+ model_weight, model_sharding = get_param_and_sharding(
359
+ params, shardings, model_key)
360
+
361
+ logger.debug(
362
+ "before transform | "
363
+ f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
364
+ )
365
+
366
+ if hf_key.endswith(".bias"):
367
+ for key in bias_reshape_keys:
368
+ if key in hf_key:
369
+ hf_weight = jnp.reshape(hf_weight, bias_reshape_keys[key])
370
+ if head_dim_pad > 0:
371
+ hf_weight = jnp.pad(hf_weight, ((0, 0), (0, head_dim_pad)))
372
+ break
373
+ else:
374
+ for key in reshape_keys:
375
+ if key in hf_key:
376
+ hf_weight = jnp.reshape(hf_weight, reshape_keys[key])
377
+ if head_dim_pad > 0:
378
+ if "o_proj" in key:
379
+ hf_weight = jnp.pad(hf_weight, ((0, 0), (0, 0),
380
+ (0, head_dim_pad)))
381
+ else:
382
+ hf_weight = jnp.pad(hf_weight,
383
+ ((0, 0), (0, head_dim_pad),
384
+ (0, 0)))
385
+ break
386
+ for key in transpose_keys:
387
+ if key in hf_key:
388
+ hf_weight = jnp.transpose(hf_weight, transpose_keys[key])
389
+ break
390
+
391
+ # Pad num-kv-heads
392
+ if hf_key.endswith(".bias"):
393
+ for key, value in bias_pad_keys.items():
394
+ dim = value[0]
395
+ dim_size = value[1]
396
+ if key in hf_key and dim_size != 0:
397
+ hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
398
+ break
399
+ else:
400
+ for key, value in pad_keys.items():
401
+ dim = value[0]
402
+ dim_size = value[1]
403
+ if key in hf_key and dim_size != 0:
404
+ hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
405
+ break
406
+
407
+ logger.debug(
408
+ "after transform | "
409
+ f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
410
+ )
411
+
412
+ if head_dim_pad == 0:
413
+ assert model_weight.value.shape == hf_weight.shape, f"{hf_key}: {model_weight.value.shape} != {hf_weight.shape}"
414
+
415
+ # Update the model weight
416
+ spec = model_weight.sharding.spec if isinstance(
417
+ model_weight.sharding, NamedSharding) else model_weight.sharding
418
+ model_weight.value = shard(hf_weight, spec)
419
+
420
+
421
+ def _is_pp_missing_layer(hf_key: str, pp_missing_layers: list[str]) -> bool:
422
+ has_digit = any(char.isdigit() for char in hf_key)
423
+ # add the suffix after digits to avoid it matches "layers.10" with "layers.1"
424
+ suffix = "." if has_digit else ""
425
+ return any(f'{pp_missing_layer}{suffix}' in hf_key
426
+ for pp_missing_layer in pp_missing_layers)
427
+
428
+
429
+ def _load_hf_weights_on_thread(
430
+ vllm_config: VllmConfig,
431
+ params: nnx.State,
432
+ metadata_map: "MetadataMap",
433
+ mesh: Mesh,
434
+ weights_file: str,
435
+ filter_regex: Optional[str] = None,
436
+ keep_original_dtype_keys_regex: Optional[list[str]] = None,
437
+ pp_missing_layers: list[str] | None = None,
438
+ ):
439
+ """Loads weights from a single weights file."""
293
440
  try:
294
441
  shardings = nnx.get_named_sharding(params, mesh)
295
442
  except TypeError:
@@ -297,160 +444,92 @@ def _load_hf_weights_on_thread(vllm_config,
297
444
 
298
445
  for hf_key, hf_weight in model_weights_single_file_generator(
299
446
  weights_file, framework="flax", filter_regex=filter_regex):
447
+ _load_and_shard_weight(
448
+ vllm_config,
449
+ params,
450
+ shardings,
451
+ metadata_map,
452
+ mesh,
453
+ hf_key,
454
+ hf_weight,
455
+ keep_original_dtype_keys_regex,
456
+ pp_missing_layers,
457
+ )
300
458
 
301
- # Check if the key should be excluded
302
- if exclude_regex:
303
- should_exclude = False
304
- for pattern in exclude_regex:
305
- if re.search(pattern, hf_key):
306
- logger.info(
307
- f"Excluding {hf_key} based on pattern {pattern}")
308
- should_exclude = True
309
- break
310
- if should_exclude:
311
- continue
312
-
313
- # Check if the key should retain its original dtype
314
- keep_original_dtype = False
315
- if keep_original_dtype_keys_regex:
316
- for pattern in keep_original_dtype_keys_regex:
317
- if re.match(pattern, hf_key):
318
- keep_original_dtype = True
319
- break
320
459
 
321
- # Converting to config's dtype
322
- if not keep_original_dtype and hf_weight.dtype != model_config.dtype:
323
- logger.warning(
324
- f"Converting dtype for {hf_key} from {hf_weight.dtype} to {model_config.dtype}"
325
- )
326
- hf_weight = hf_weight.astype(model_config.dtype)
327
-
328
- if hf_key.endswith(".weight"):
329
- hf_key = hf_key.removesuffix(".weight")
330
-
331
- # Find the corresponding model key using the HF key
332
- if "layers" in hf_key:
333
- layer_num = re.search(r"layers\.(\d+)", hf_key).group(1)
334
- layer_key = re.sub(r"layers\.\d+", "layers.*", hf_key)
335
- model_key = name_map[layer_key]
336
- model_key = re.sub(r"layers\.\*", f"layers.{layer_num}", model_key)
337
- elif "blocks" in hf_key:
338
- layer_num = re.search(r"blocks\.(\d+)", hf_key).group(1)
339
- layer_key = re.sub(r"blocks\.\d+", "blocks.*", hf_key)
340
- model_key = name_map[layer_key]
341
- model_key = re.sub(r"blocks\.\*", f"blocks.{layer_num}", model_key)
342
- else:
343
- if hf_key not in name_map and hf_key == "lm_head":
344
- logger.warning(
345
- f"Skip loading {hf_key} due to tie_word_embeddings")
346
- continue
347
- if hf_key not in name_map and "t2d" in hf_key:
348
- logger.warning(
349
- f"Skip loading {hf_key} as it's not used in eagle-3 for now"
350
- )
460
+ def load_hf_weights(
461
+ vllm_config: VllmConfig,
462
+ model: nnx.Module,
463
+ metadata_map: "MetadataMap",
464
+ mesh: Mesh,
465
+ filter_regex: Optional[str] = None,
466
+ is_draft_model: bool = False,
467
+ keep_original_dtype_keys_regex: Optional[list[str]] = None,
468
+ pp_missing_layers: list[str] | None = None,
469
+ ):
470
+ """Load weights into a JAX model from either an iterator or files."""
471
+ params = nnx.state(model)
472
+ try:
473
+ shardings = nnx.get_named_sharding(params, mesh)
474
+ except TypeError:
475
+ shardings = params
476
+ weights_iterator = None
477
+ if hasattr(vllm_config.model_config, "model_weights_iterator"):
478
+ weights_iterator = vllm_config.model_config.model_weights_iterator
479
+ env = torchax.default_env()
480
+ # The weights_iterator is used in RunAI model streamer integration.
481
+ if weights_iterator is not None:
482
+ for hf_key, hf_weight in weights_iterator:
483
+ if filter_regex and not re.match(filter_regex, hf_key):
351
484
  continue
352
- model_key = name_map.get(hf_key, hf_key)
353
- model_weight, model_sharding = get_param_and_sharding(
354
- params, shardings, model_key)
355
-
356
- logger.debug(
357
- "before transform | "
358
- f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
359
- )
360
485
 
361
- if hf_key.endswith(".bias"):
362
- for key in bias_reshape_keys:
363
- if key in hf_key:
364
- hf_weight = jnp.reshape(hf_weight, bias_reshape_keys[key])
365
- if head_dim_pad > 0:
366
- hf_weight = jnp.pad(hf_weight,
367
- ((0, 0), (0, head_dim_pad)))
368
- break
369
- else:
370
- for key in reshape_keys:
371
- if key in hf_key:
372
- hf_weight = jnp.reshape(hf_weight, reshape_keys[key])
373
- if head_dim_pad > 0:
374
- if "o_proj" in key:
375
- hf_weight = jnp.pad(hf_weight, ((0, 0), (0, 0),
376
- (0, head_dim_pad)))
377
- else:
378
- hf_weight = jnp.pad(hf_weight,
379
- ((0, 0), (0, head_dim_pad),
380
- (0, 0)))
381
- break
382
- for key in transpose_keys:
383
- if key in hf_key:
384
- hf_weight = jnp.transpose(hf_weight, transpose_keys[key])
385
- break
386
-
387
- # Pad num-kv-heads
388
- if hf_key.endswith(".bias"):
389
- for key, value in bias_pad_keys.items():
390
- dim = value[0]
391
- dim_size = value[1]
392
- if key in hf_key and dim_size != 0:
393
- hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
394
- break
395
- else:
396
- for key, value in pad_keys.items():
397
- dim = value[0]
398
- dim_size = value[1]
399
- if key in hf_key and dim_size != 0:
400
- hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
401
- break
402
-
403
- logger.debug(
404
- "after transform | "
405
- f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
406
- )
486
+ # Since the weights_iterator yields Pytorch tensors (torch.Tensor),
487
+ # we need to convert them to JAX arrays (jax.Array).
488
+ hf_weight_jax = env.t2j_copy(hf_weight)
407
489
 
408
- if head_dim_pad == 0:
409
- assert model_weight.value.shape == hf_weight.shape, f"{hf_key}: {model_weight.value.shape} != {hf_weight.shape}"
410
-
411
- # Update the model weight
412
- spec = model_weight.sharding.spec if isinstance(
413
- model_weight.sharding, NamedSharding) else model_weight.sharding
414
- model_weight.value = shard(hf_weight, spec)
415
-
416
-
417
- def load_hf_weights(vllm_config,
418
- model: nnx.Module,
419
- metadata_map: MetadataMap,
420
- mesh: Mesh,
421
- filter_regex: str | None = None,
422
- is_draft_model: bool = False,
423
- keep_original_dtype_keys_regex: list[str] | None = None,
424
- exclude_regex: list[str] | None = None):
425
- """Load weights from all model weights files to the model, run in multi threads."""
426
- if is_draft_model:
427
- model_path = vllm_config.speculative_config.draft_model_config.model
428
- else:
429
- model_path = vllm_config.model_config.model
430
- weights_files = get_model_weights_files(
431
- model_path, vllm_config.load_config.download_dir)
432
- params = nnx.state(model)
433
- max_workers = min(64, len(weights_files))
434
- # NOTE(xiang): Disable multi-threading mode if running on multi-host.
435
- # Because multi-threading would cause different JAX processes to load
436
- # different weights at the same time.
437
- if envs.TPU_MULTIHOST_BACKEND == "ray":
438
- max_workers = 1
439
- with ThreadPoolExecutor(max_workers=max_workers) as executor:
440
- futures = [
441
- executor.submit(
442
- _load_hf_weights_on_thread,
490
+ _load_and_shard_weight(
443
491
  vllm_config,
444
492
  params,
493
+ shardings,
445
494
  metadata_map,
446
495
  mesh,
447
- weights_file,
448
- filter_regex=filter_regex,
449
- keep_original_dtype_keys_regex=keep_original_dtype_keys_regex,
450
- exclude_regex=exclude_regex) for weights_file in weights_files
451
- ]
452
- for future in futures:
453
- future.result()
496
+ hf_key,
497
+ hf_weight_jax,
498
+ keep_original_dtype_keys_regex,
499
+ pp_missing_layers=pp_missing_layers,
500
+ )
501
+ else:
502
+ # File-based path (multi-threaded)
503
+ if is_draft_model:
504
+ model_path = vllm_config.speculative_config.draft_model_config.model
505
+ else:
506
+ model_path = vllm_config.model_config.model
507
+ weights_files = get_model_weights_files(
508
+ model_path, vllm_config.load_config.download_dir)
509
+ max_workers = min(64, len(weights_files))
510
+ # NOTE(xiang): Disable multi-threading mode if running on multi-host.
511
+ # Because multi-threading would cause different JAX processes to load
512
+ # different weights at the same time.
513
+ if envs.TPU_MULTIHOST_BACKEND == "ray":
514
+ max_workers = 1
515
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
516
+ futures = [
517
+ executor.submit(
518
+ _load_hf_weights_on_thread,
519
+ vllm_config,
520
+ params,
521
+ metadata_map,
522
+ mesh,
523
+ weights_file,
524
+ filter_regex=filter_regex,
525
+ keep_original_dtype_keys_regex=
526
+ keep_original_dtype_keys_regex,
527
+ pp_missing_layers=pp_missing_layers,
528
+ ) for weights_file in weights_files
529
+ ]
530
+ for future in futures:
531
+ future.result()
532
+
454
533
  check_all_loaded(params)
455
534
  nnx.update(model, params)
456
535
 
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -1,3 +1,17 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  import copy
2
16
  import functools
3
17
  from collections.abc import Sequence
@@ -9,6 +23,7 @@ import jax
9
23
  import torch
10
24
  import torch.nn
11
25
  import torchax
26
+ import vllm.envs as vllm_envs
12
27
  from flax.typing import PRNGKey
13
28
  from jax.sharding import Mesh, NamedSharding, PartitionSpec
14
29
  from torchax.interop import jax_view, torch_view
@@ -22,8 +37,10 @@ from vllm.model_executor.models import supports_lora, supports_multimodal
22
37
  from vllm.sequence import IntermediateTensors
23
38
 
24
39
  from tpu_inference.layers.common.attention_metadata import AttentionMetadata
40
+ from tpu_inference.layers.common.sharding import ShardingAxisName
41
+ from tpu_inference.layers.vllm.process_weights.cleanup_sharding import \
42
+ shard_model_to_tpu
25
43
  from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
26
- from tpu_inference.layers.vllm.sharding import shard_model_to_tpu
27
44
  from tpu_inference.logger import init_logger
28
45
  from tpu_inference.models.jax.jax_intermediate_tensor import \
29
46
  JaxIntermediateTensors
@@ -118,9 +135,16 @@ class VllmModelWrapper:
118
135
  "torch._sync",
119
136
  return_value=None) if use_random_weights else nullcontext()
120
137
 
138
+ # By default load weights to the CPU device first. If we are running
139
+ # under Pathways, this would cause weights to be loaded on a CPU-only
140
+ # node, so we'll need to remove this context.
141
+ jax_context = jax.default_device(
142
+ jax.devices("cpu")
143
+ [0]) if not vllm_envs.VLLM_TPU_USING_PATHWAYS else nullcontext()
144
+
121
145
  # Load the vLLM model and wrap it into a new model whose forward
122
146
  # function can calculate the hidden_state and logits.
123
- with load_context, jax.default_device(jax.devices("cpu")[0]):
147
+ with load_context, jax_context:
124
148
  vllm_model = vllm_get_model(vllm_config=vllm_config_for_load)
125
149
  lora_manager = None
126
150
  if vllm_config_for_load.lora_config is not None:
@@ -189,7 +213,7 @@ class VllmModelWrapper:
189
213
  kwargs={
190
214
  "input_ids": torch_view(input_ids),
191
215
  "positions": torch_view(input_positions),
192
- "intermediate_tensors": None,
216
+ "intermediate_tensors": intermediate_tensors,
193
217
  "inputs_embeds": None,
194
218
  },
195
219
  tie_weights=False,
@@ -212,8 +236,10 @@ class VllmModelWrapper:
212
236
 
213
237
  @functools.partial(
214
238
  jax.jit,
215
- out_shardings=(NamedSharding(self.mesh,
216
- PartitionSpec(None, "model"))),
239
+ out_shardings=(NamedSharding(
240
+ self.mesh,
241
+ PartitionSpec(ShardingAxisName.MLP_DATA,
242
+ ShardingAxisName.MLP_TENSOR))),
217
243
  )
218
244
  def compute_logits_func(
219
245
  params_and_buffers: Any,
@@ -255,7 +281,6 @@ def load_lora_model(model: torch.nn.Module, vllm_config: VllmConfig,
255
281
  vllm_config,
256
282
  device,
257
283
  model.embedding_modules,
258
- model.embedding_padding_modules,
259
284
  )
260
285
  return lora_manager, lora_manager.create_lora_manager(model)
261
286
 
@@ -269,10 +294,9 @@ def replace_set_lora(model):
269
294
  index: int,
270
295
  lora_a: torch.Tensor,
271
296
  lora_b: torch.Tensor,
272
- embeddings_tensor: Optional[torch.Tensor],
273
297
  ):
274
298
  with torchax.default_env():
275
- self._original_set_lora(index, lora_a, lora_b, embeddings_tensor)
299
+ self._original_set_lora(index, lora_a, lora_b)
276
300
 
277
301
  def _tpu_reset_lora(self, index: int):
278
302
  with torchax.default_env():
@@ -1,3 +1,17 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  from contextlib import contextmanager
2
16
  from dataclasses import dataclass
3
17
  from typing import Dict, List, Optional
@@ -1,2 +1,16 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  # ruff: noqa
2
16
  from tpu_inference.platforms.tpu_platform import TpuPlatform