tpu_inference-0.12.0.dev20251222-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
@@ -0,0 +1,643 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import List, Optional, Tuple

import jax
import jax.numpy as jnp
from flax import nnx
from flax.typing import PRNGKey
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from vllm.config import VllmConfig

from tpu_inference.layers.jax.attention.attention import AttentionMetadata
from tpu_inference.layers.jax.attention.llama4_attention import Llama4Attention
from tpu_inference.layers.jax.constants import KVCacheType
from tpu_inference.layers.jax.layers import DenseFFW, Embedder, LMhead, RMSNorm
from tpu_inference.layers.jax.misc import shard_put
from tpu_inference.layers.jax.moe.moe import MoE, Router
from tpu_inference.layers.jax.transformer_block import \
    SharedExpertsTransformerBlock
from tpu_inference.logger import init_logger
from tpu_inference.models.jax.utils.weight_utils import (
    convert_torch_to_jax_with_view, get_param, model_weights_generator,
    print_param_info, reshape_params, transpose_params)

logger = init_logger(__name__)


class Llama4ForCausalLM(nnx.Module):

    def __init__(self,
                 vllm_config: VllmConfig,
                 rng: PRNGKey,
                 mesh: Mesh,
                 force_random_weights: bool = False):
        assert mesh is not None

        self.vllm_config = vllm_config
        model_config = vllm_config.model_config
        text_config = model_config.hf_config.text_config

        self.rng = nnx.Rngs(rng)
        self.mesh = mesh
        self.is_verbose = getattr(self.vllm_config.additional_config,
                                  "is_verbose", False)

        # Currently the runner will always set a mesh, so the custom default sharding (when
        # no sharding is set in vllm config) doesn't take effect.
        # TODO(fhzhang): figure out whether we need to actually enable this.
        # strategy_dict = {"tensor_parallelism": 4, "expert_parallelism": 2}

        self.vocab_size = model_config.get_vocab_size()
        self.hidden_size = model_config.get_hidden_size()

        dtype: jnp.dtype = jnp.bfloat16

        self.num_layers: int = getattr(text_config, "num_hidden_layers", 48)

        self.intermediate_size_moe: int = getattr(text_config,
                                                  "intermediate_size", 8192)
        self.intermediate_size_mlp = getattr(text_config,
                                             "intermediate_size_mlp", 16384)

        # num_local_experts: uses 16 experts for Llama-4-Scout-17B-16E-Instruct and 128 experts for Llama-4-Maverick-17B-128E-Instruct.
        # The default value is set to 16 for compatibility with Llama-4-Scout.
        self.num_local_experts: int = getattr(text_config, "num_local_experts",
                                              16)
        self.hidden_act: str = getattr(text_config, "hidden_act", "silu")
        self.no_rope_layer_interval = 4

        # interleave_moe_layer_step has a layer step of 2 to interleave MoE and dense layers for Llama-4-Maverick-17B-128E-Instruct.
        # The default value is set to 1 for compatibility with Llama-4-Scout.
        self.interleave_moe_layer_step = getattr(text_config,
                                                 "interleave_moe_layer_step",
                                                 1)

        self.num_attention_heads = getattr(text_config, "num_attention_heads",
                                           40)
        self.num_key_value_heads = getattr(text_config, "num_key_value_heads",
                                           8)
        self.head_dim = getattr(text_config, "head_dim", 128)

        self.num_shared_experts = getattr(text_config, "num_experts_per_tok",
                                          1)
        self.rms_norm_eps = getattr(text_config, "rms_norm_eps", 1e-5)

        self.rope_scaling = getattr(text_config, "rope_scaling", None)
        if self.rope_scaling:
            self.rope_scaling["scale_factor"] = self.rope_scaling.pop("factor")

        self.use_qk_norm = getattr(text_config, "use_qk_norm", True)

        self.embedder = Embedder(vocab_size=self.vocab_size,
                                 hidden_size=self.hidden_size,
                                 dtype=dtype,
                                 vd_sharding=(('data', 'expert', 'model'),
                                              None),
                                 rngs=self.rng,
                                 random_init=force_random_weights)

        self.layers = []

        for i in range(self.num_layers):
            # For Llama4-Scout, all layers are MoE layers.
            # This can be adjusted for other variants.
            is_moe_layer = (i + 1) % \
                self.interleave_moe_layer_step == 0

            # Llama-4-Scout config: It has "no_rope_layers": []
            use_attention_rope = (i + 1) % self.no_rope_layer_interval != 0

            router = Router(dtype=dtype,
                            hidden_size=self.hidden_size,
                            num_experts=self.num_local_experts,
                            num_experts_per_tok=1,
                            router_act="sigmoid",
                            rngs=self.rng,
                            activation_ffw_td=('data', None),
                            ed_sharding=(None, None),
                            random_init=force_random_weights)

            moe_ffw = MoE(
                dtype=dtype,
                num_local_experts=self.num_local_experts,
                apply_expert_weight_before_computation=True,
                hidden_size=self.hidden_size,
                intermediate_size_moe=self.intermediate_size_moe,
                hidden_act=self.hidden_act,
                router=router,
                rngs=self.rng,
                activation_ffw_td=('data', None),
                activation_ffw_ted=('data', 'expert', None),
                edf_sharding=('model', None, None),
                efd_sharding=('model', None, None),
                random_init=force_random_weights) if is_moe_layer else None

            dense_ffw = DenseFFW(
                dtype=dtype,
                hidden_act=self.hidden_act,
                hidden_size=self.hidden_size,
                intermediate_size=self.intermediate_size_mlp,
                random_init=force_random_weights,
                rngs=self.rng,
                df_sharding=(None, 'model'),
                fd_sharding=('model', None),
                activation_ffw_td=('data', None)) if not is_moe_layer else None

            attn = Llama4Attention(
                hidden_size=self.hidden_size,
                dtype=dtype,
                kv_cache_dtype=vllm_config.cache_config.cache_dtype,
                num_attention_heads=self.num_attention_heads,
                num_key_value_heads=self.num_key_value_heads,
                head_dim=self.head_dim,
                rope_theta=500000.0,
                # https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/config.json
                rope_scaling=self.rope_scaling,
                rngs=self.rng,
                rope_input_ordering="interleaved",
                temperature_tuning=True,
                temperature_tuning_scale=0.1,
                temperature_tuning_floor_scale=8192,
                use_qk_norm=self.use_qk_norm,
                attention_chunk_size=None if use_attention_rope else 8192,
                mesh=self.mesh,
                random_init=force_random_weights,
                activation_attention_td=('data', 'model'),
                activation_q_td=('data', 'model'),
                query_tnh=P('data', 'model', None),
                keyvalue_skh=P('data', 'model', None),
                activation_attention_out_td=('data', 'model'),
                attn_o_tnh=P('data', 'model', None),
                dnh_sharding=(None, 'model', None),
                dkh_sharding=(None, 'model', None),
                nhd_sharding=('model', None, None),
            )

            shared_experts = DenseFFW(
                dtype=dtype,
                hidden_act=self.hidden_act,
                hidden_size=self.hidden_size,
                intermediate_size=self.num_shared_experts *
                self.intermediate_size_moe,
                rngs=self.rng,
                random_init=force_random_weights,
                df_sharding=(None, 'model'),
                fd_sharding=('model', None),
                activation_ffw_td=('data', None)) if is_moe_layer else None

            pre_attention_norm = RMSNorm(
                dims=self.hidden_size,
                random_init=force_random_weights,
                epsilon=self.rms_norm_eps,
                rngs=self.rng,
                with_scale=True,
                dtype=dtype,
                activation_ffw_td=('data', None),
            )

            pre_mlp_norm = RMSNorm(
                dims=self.hidden_size,
                epsilon=self.rms_norm_eps,
                rngs=self.rng,
                with_scale=True,
                dtype=dtype,
                random_init=force_random_weights,
                activation_ffw_td=('data', None),
            )

            block = SharedExpertsTransformerBlock(
                moe_ffw=moe_ffw if is_moe_layer else None,
                dense_ffw=dense_ffw if not is_moe_layer else None,
                shared_experts=shared_experts if is_moe_layer else None,
                attn=attn,
                pre_attention_norm=pre_attention_norm,
                pre_mlp_norm=pre_mlp_norm,
                use_attention_rope=use_attention_rope)
            self.layers.append(block)

        self.final_norm = RMSNorm(
            dims=self.hidden_size,
            epsilon=self.rms_norm_eps,
            rngs=self.rng,
            with_scale=True,
            dtype=dtype,
            random_init=force_random_weights,
        )

        self.lm_head = LMhead(vocab_size=self.vocab_size,
                              hidden_size=self.hidden_size,
                              dtype=dtype,
                              rngs=self.rng,
                              vd_sharding=(('data', 'expert', 'model'), None),
                              dv_sharding=(None, ('data', 'expert', 'model')),
                              random_init=force_random_weights)
        if self.is_verbose:
            self._print_model_architecture()

    def _print_model_architecture(self):
        num_display_layers = max(self.interleave_moe_layer_step,
                                 self.no_rope_layer_interval)

        logger.info("### Embedding ###")
        nnx.display(self.embedder)

        logger.info(f"\n### First {num_display_layers} Layers ###")
        # Loop through the slice and display each layer
        for i, layer in enumerate(self.layers[:num_display_layers]):
            logger.info(f"\n--- Layer {i} ---")
            nnx.display(layer)

        logger.info("\n### LM Head ###")
        nnx.display(self.lm_head)

    def load_weights(self, rng: jax.Array, cache_dir: Optional[str] = None):
        # NOTE: Since we are using nnx.eval_shape to init the model,
        # we have to pass dynamic arrays here for __call__'s usage.
        self.rng = nnx.Rngs(rng)

        weight_loader = Llama4WeightLoader(
            vllm_config=self.vllm_config,
            hidden_size=self.hidden_size,
            attn_heads=self.num_attention_heads,
            num_key_value_heads=self.num_key_value_heads,
            attn_head_dim=self.head_dim)
        weight_loader.load_weights(self)

    def __call__(
        self,
        kv_caches: List[jax.Array],
        input_ids: jax.Array,
        attention_metadata: AttentionMetadata,
        *args,
    ) -> Tuple[List[KVCacheType], jax.Array, List[jax.Array]]:
        is_prefill = False
        x_TD = self.embedder.encode(input_ids)

        for (i, block) in enumerate(self.layers):
            kv_cache = kv_caches[i]
            new_kv_cache, x_TD = block(x_TD, is_prefill, kv_cache,
                                       attention_metadata)
            jax.block_until_ready(x_TD)
            kv_caches[i] = new_kv_cache

        final_activation_TD = self.final_norm(x_TD)

        return kv_caches, final_activation_TD, []

    def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
        logits_TV = jnp.dot(hidden_states,
                            self.lm_head.input_embedding_table_DV.value)
        return logits_TV


class Llama4WeightLoader:

    def __init__(self, vllm_config: VllmConfig, hidden_size, attn_heads,
                 num_key_value_heads, attn_head_dim):
        self.names_and_weights_generator = model_weights_generator(
            model_name_or_path=vllm_config.model_config.model,
            framework="pt",
            filter_regex="language_model",
            download_dir=vllm_config.load_config.download_dir)
        self.is_verbose = getattr(vllm_config.additional_config, "is_verbose",
                                  False)
        self.interleave_moe_layer_step = getattr(
            vllm_config.model_config.hf_config.text_config,
            "interleave_moe_layer_step", 1)

        self.quantization_config = getattr(vllm_config.model_config.hf_config,
                                           "quantization_config", None)
        self.expert_weights_buffer = {}
        self.expert_prefix = "shared_expert."

        transpose_mappings_to_quantization = {
            "down_proj": (1, 0),
            "gate_proj": (1, 0),
            "up_proj": (1, 0),
        }

        self._transpose_map = {
            "q_proj": (2, 0, 1),
            "k_proj": (2, 0, 1),
            "v_proj": (2, 0, 1),
            "router": (1, 0),
            f"{self.expert_prefix}down_proj": (1, 0),
            f"{self.expert_prefix}gate_proj": (1, 0),
            f"{self.expert_prefix}up_proj": (1, 0),
            "feed_forward.down_proj": (1, 0),
            "feed_forward.gate_proj": (1, 0),
            "feed_forward.up_proj": (1, 0),
            "o_proj": (1, 2, 0),
            "lm_head": (1, 0),
        }

        if self.quantization_config and self.expert_prefix:
            self._transpose_map.update(transpose_mappings_to_quantization)

        self._weight_shape_map = {
            "q_proj": (attn_heads, attn_head_dim, hidden_size),
            "k_proj": (num_key_value_heads, attn_head_dim, hidden_size),
            "v_proj": (num_key_value_heads, attn_head_dim, hidden_size),
            # o_proj is inverted: https://github.com/huggingface/transformers/blob/v4.53.2/src/transformers/models/llama4/modeling_llama4.py#L298
            "o_proj": (hidden_size, attn_heads, attn_head_dim),
        }

        # Set the mappings from loaded parameter keys to standardized names.
        # 1. EXPERT_MAPPINGS_FUSED: Used for non-quantized (e.g., BF16) checkpoints.
        #    - This format typically comes from standard checkpoints where 'gate' and 'up' projection weights might be combined (FUSED) into a single tensor.
        #    - Expert weights are usually stacked, with the expert dimension (E) being the first dimension.
        EXPERT_MAPPINGS_FUSED = {
            "language_model.model.layers.*.feed_forward.experts.down_proj":
            "layers.*.moe_ffw.kernel_down_proj_EFD",
            "language_model.model.layers.*.feed_forward.experts.gate_up_proj":
            "layers.*.moe_ffw.kernel_up_proj_EDF",
        }

        # 2. EXPERT_MAPPINGS_UNFUSED: Specifically designed for quantized checkpoints (e.g., FP8).
        #    - Quantized checkpoints store each expert's weights separately and explicitly separate the 'weight' (quantized value) from the 'weight_scale' (quantization scale).
        #    - The mapping captures both the `.weight` and `.weight_scale` components. This allows the loader to aggregate (stack) the individual expert weights and scales.
        EXPERT_MAPPINGS_UNFUSED = {
            "language_model.model.layers.*.feed_forward.experts.*.down_proj.weight":
            "layers.*.moe_ffw.kernel_down_proj_EFD",
            "language_model.model.layers.*.feed_forward.experts.*.down_proj.weight_scale":
            "layers.*.moe_ffw.kernel_down_proj_EFD",
            "language_model.model.layers.*.feed_forward.experts.*.gate_proj.weight":
            "layers.*.moe_ffw.kernel_gating_EDF",
            "language_model.model.layers.*.feed_forward.experts.*.gate_proj.weight_scale":
            "layers.*.moe_ffw.kernel_gating_EDF",
            "language_model.model.layers.*.feed_forward.experts.*.up_proj.weight":
            "layers.*.moe_ffw.kernel_up_proj_EDF",
            "language_model.model.layers.*.feed_forward.experts.*.up_proj.weight_scale":
            "layers.*.moe_ffw.kernel_up_proj_EDF",
        }

        self._loaded_to_standardized_keys = {
            "language_model.model.embed_tokens.weight":
            "embedder.input_embedding_table_VD",
            "language_model.lm_head.weight":
            "lm_head.input_embedding_table_DV",
            "language_model.model.norm.weight":
            "final_norm.scale",
            "language_model.model.layers.*.input_layernorm.weight":
            "layers.*.pre_attention_norm.scale",
            "language_model.model.layers.*.post_attention_layernorm.weight":
            "layers.*.pre_mlp_norm.scale",
            "language_model.model.layers.*.self_attn.q_proj.weight":
            "layers.*.attn.kernel_q_proj_DNH",
            "language_model.model.layers.*.self_attn.k_proj.weight":
            "layers.*.attn.kernel_k_proj_DKH",
            "language_model.model.layers.*.self_attn.v_proj.weight":
            "layers.*.attn.kernel_v_proj_DKH",
            "language_model.model.layers.*.self_attn.o_proj.weight":
            "layers.*.attn.kernel_o_proj_NHD",
            "language_model.model.layers.*.feed_forward.router.weight":
            "layers.*.moe_ffw.router.kernel_DE",
            # shared experts
            "language_model.model.layers.*.feed_forward.shared_expert.down_proj.weight":
            "layers.*.shared_experts.kernel_down_proj_FD",
            "language_model.model.layers.*.feed_forward.shared_expert.gate_proj.weight":
            "layers.*.shared_experts.kernel_gating_DF",
            "language_model.model.layers.*.feed_forward.shared_expert.up_proj.weight":
            "layers.*.shared_experts.kernel_up_proj_DF",
            # dense layers
            "language_model.model.layers.*.feed_forward.down_proj.weight":
            "layers.*.dense_ffw.kernel_down_proj_FD",
            "language_model.model.layers.*.feed_forward.up_proj.weight":
            "layers.*.dense_ffw.kernel_up_proj_DF",
            "language_model.model.layers.*.feed_forward.gate_proj.weight":
            "layers.*.dense_ffw.kernel_gating_DF",
        }

        if self.quantization_config is None:
            self._loaded_to_standardized_keys.update(EXPERT_MAPPINGS_FUSED)
        else:
            self._loaded_to_standardized_keys.update(EXPERT_MAPPINGS_UNFUSED)

    def map_loaded_to_standardized_name(self, loaded_key: str) -> str:
        # Find the corresponding model key using the HF key
        if "layer" in loaded_key:
            layer_num = self._get_layer_num(loaded_key)
            layer_key = re.sub(r"layers\.\d+", "layers.*", loaded_key)

            expert_match = re.search(r"experts\.(\d+)", layer_key)
            if expert_match:
                # Key for lookup eg: layers.*.feed_forward.experts.*.down_proj.weight
                layer_key = re.sub(r"experts\.\d+", "experts.*", layer_key)

            mapped_key = self._loaded_to_standardized_keys.get(
                layer_key, loaded_key)
            mapped_key = re.sub(r"layers\.\*", f"layers.{layer_num}",
                                mapped_key)
        else:
            mapped_key = self._loaded_to_standardized_keys.get(
                loaded_key, loaded_key)
        return mapped_key

    def _map_llama4_gate_up_proj(self, model_for_loading: nnx.Module,
                                 model_params: nnx.State, loaded_name: str,
                                 loaded_weight: jax.Array):
        """HF's gate_up_proj is a fused tensor of gate and up projections. It needs to be split."""

        cast_type = jnp.dtype(jnp.bfloat16)
        # loaded_weight is a jax.Array when framework="flax", otherwise it's bfloat16
        if not isinstance(loaded_weight, jax.Array):
            loaded_weight = convert_torch_to_jax_with_view(
                loaded_weight, cast_type)

        split_weights = jnp.split(loaded_weight, 2, axis=-1)
        layer_num = self._get_layer_num(loaded_name)

        for split_type in ["gate", "up"]:
            split_loaded_name = loaded_name.replace("gate_up_proj",
                                                    f"{split_type}_proj")
            if split_type == "gate":
                mapped_name = "layers.*.moe_ffw.kernel_gating_EDF"
                loaded_weight = split_weights[0]
            else:
                mapped_name = "layers.*.moe_ffw.kernel_up_proj_EDF"
                loaded_weight = split_weights[1]

            mapped_name = re.sub(r"layers\.\*", f"layers.{layer_num}",
                                 mapped_name)

            mapped_model_weight = get_param(model_params, mapped_name)

            if mapped_model_weight.value.shape != loaded_weight.shape:
                raise ValueError(
                    f"Loaded shape for {split_loaded_name}: {loaded_weight.shape} "
                    f"does not match model shape for {mapped_name}: {mapped_model_weight.value.shape}!"
                )

            mapped_model_weight.value = shard_put(loaded_weight,
                                                  mapped_model_weight.sharding,
                                                  mesh=model_for_loading.mesh)
            logger.debug(
                f"{split_loaded_name}: {loaded_weight.shape} --> {mapped_name}: {mapped_model_weight.value.shape}"
            )
            if self.is_verbose:
                print_param_info(mapped_model_weight, mapped_name)

    def _get_layer_num(self, loaded_key: str) -> Optional[int]:
        """
        Extracts the layer number from a HuggingFace weight key string.
        Returns the layer number (int) or None if no layer number is found.
        """
        match = re.search(r"layers\.(\d+)", loaded_key)
        if match:
            return int(match.group(1))
        return None

    def _get_expert_num(self, loaded_key: str) -> Optional[int]:
        """
        Extracts the expert number from a HuggingFace weight key string.
        Returns the expert number (int) or None if no expert number is found.
        """
        match = re.search(r"experts\.(\d+)\.", loaded_key)
        if match:
            return int(match.group(1))
        return None

    def load_weights(self, model_for_loading: nnx.Module):
        model_params = nnx.state(model_for_loading)

        with jax.default_device(jax.devices("cpu")[0]):
            for loaded_name, loaded_weight in self.names_and_weights_generator:
                is_moe_layer = False
                layer_num = self._get_layer_num(loaded_name)
                expert_num = self._get_expert_num(loaded_name)
                # Quantized (FP8) checkpoints unstack the expert weights, while unquantized (BF16) checkpoints keep them stacked.
                is_unfused_expert = self.quantization_config is not None and expert_num is not None
                is_scale = loaded_name.endswith(".weight_scale")

                if is_unfused_expert:
                    mapped_name = self.map_loaded_to_standardized_name(
                        loaded_name)
                    model_weight = get_param(model_params, mapped_name)

                    if is_scale:
                        cast_type = model_weight.array.scale.value.dtype
                    else:
                        cast_type = model_weight.array.qvalue.value.dtype

                    loaded_weight = convert_torch_to_jax_with_view(
                        loaded_weight, cast_type)
                    loaded_weight = transpose_params(loaded_name,
                                                     loaded_weight,
                                                     self._transpose_map)

                    buffer_key = f"{mapped_name}_{'scale' if is_scale else 'qvalue'}"
                    if buffer_key not in self.expert_weights_buffer:
                        self.expert_weights_buffer[buffer_key] = {}
                    self.expert_weights_buffer[buffer_key][
                        expert_num] = loaded_weight
                    continue

                if layer_num is not None:
                    is_moe_layer = (layer_num + 1) % \
                        self.interleave_moe_layer_step == 0
                    self.expert_prefix = "shared_expert." if is_moe_layer else ""

                if "gate_up_proj" in loaded_name:
                    self._map_llama4_gate_up_proj(model_for_loading,
                                                  model_params, loaded_name,
                                                  loaded_weight)
                    continue

                mapped_name = self.map_loaded_to_standardized_name(loaded_name)
                model_weight = get_param(model_params, mapped_name)

                cast_type = model_weight.value.dtype
                if not isinstance(loaded_weight, jax.Array):
                    logger.debug(
                        f"Converting PyTorch tensor {loaded_name} to JAX {cast_type}"
                    )
                    loaded_weight = convert_torch_to_jax_with_view(
                        loaded_weight, cast_type)

                if not loaded_name.endswith(".bias"):
                    loaded_weight = reshape_params(loaded_name, loaded_weight,
                                                   self._weight_shape_map)
                    loaded_weight = transpose_params(loaded_name,
                                                     loaded_weight,
                                                     self._transpose_map)
                if model_weight.value.shape != loaded_weight.shape:
                    raise ValueError(
                        f"Loaded shape for {loaded_name}: {loaded_weight.shape} "
                        f"does not match model shape for {mapped_name}: {model_weight.value.shape}!"
                    )
                logger.debug(
                    f"Transformed parameter {loaded_name} to {mapped_name}: {loaded_weight.shape} --> {model_weight.value.shape}"
                )

                model_weight.value = shard_put(loaded_weight,
                                               model_weight.sharding,
                                               mesh=model_for_loading.mesh)
                if self.is_verbose:
                    print_param_info(model_weight, loaded_name)

        with jax.default_device(jax.devices("cpu")[0]):
            for buffer_key, expert_map in self.expert_weights_buffer.items(
            ):
                sorted_exp_nums = sorted(expert_map.keys())
                aggregated_weight = jnp.stack(
                    [expert_map[k] for k in sorted_exp_nums], axis=0)
                is_scale = buffer_key.endswith("_scale")
                base_mapped_name = buffer_key.replace("_scale",
                                                      "").replace(
                                                          "_qvalue", "")

                model_weight = get_param(model_params, base_mapped_name)

                assert hasattr(
                    model_weight, 'array'
                ), f"Expected MoE weight '{base_mapped_name}' to be a quantized array (qarray)"

                if is_scale:
                    loaded_name = f"{base_mapped_name}.array.scale.value"
                    if model_weight.array.scale.value.shape != aggregated_weight.shape:
                        raise ValueError(
                            f"[AGGREGATED] Loaded shape for {buffer_key}: {aggregated_weight.shape} "
                            f"does not match model shape for {loaded_name}: {model_weight.array.scale.value.shape}!"
                        )

                    model_weight.array.scale.value = shard_put(
                        aggregated_weight,
                        model_weight.array.scale.sharding,
                        mesh=model_for_loading.mesh)

                elif aggregated_weight.itemsize < 2:  # check model weight elem nbits < 16
                    loaded_name = f"{base_mapped_name}.array.qvalue.value"
                    if model_weight.array.qvalue.value.shape != aggregated_weight.shape:
                        raise ValueError(
                            f"[AGGREGATED] Loaded shape for {buffer_key}: {aggregated_weight.shape} "
                            f"does not match model shape for {loaded_name}: {model_weight.array.qvalue.value.shape}!"
                        )

                    model_weight.array.qvalue.value = shard_put(
                        aggregated_weight,
                        model_weight.array.qvalue.sharding,
                        mesh=model_for_loading.mesh)

                logger.debug(
                    f"Aggregated and loaded {loaded_name}: {aggregated_weight.shape}"
                )

                if self.is_verbose:
                    print_param_info(model_weight, loaded_name)

        nnx.update(model_for_loading, model_params)
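
For orientation, the sketch below shows how this module is typically exercised. It is illustrative only: vllm_config (a populated vllm VllmConfig for a Llama-4 checkpoint), mesh (a jax.sharding.Mesh with 'data', 'expert' and 'model' axes), and the per-step kv_caches, input_ids and attention_metadata are assumed to be supplied by the tpu_inference runner and are not constructed here.

# Illustrative usage sketch (not part of the packaged file). Assumes vllm_config,
# mesh, kv_caches, input_ids and attention_metadata are provided by the runner.
import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)

model = Llama4ForCausalLM(vllm_config, rng, mesh)
# load_weights maps HF checkpoint keys onto nnx params, e.g.
# "language_model.model.layers.3.self_attn.q_proj.weight" -> "layers.3.attn.kernel_q_proj_DNH"
model.load_weights(rng)

# One forward step: returns the updated KV caches and the final hidden states.
kv_caches, hidden_TD, _ = model(kv_caches, input_ids, attention_metadata)
logits_TV = model.compute_logits(hidden_TD)
next_token_id = int(jnp.argmax(logits_TV[-1]))

The three-element return value (updated kv_caches, final hidden states, empty aux list) mirrors the __call__ signature declared above.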