tpu_inference-0.11.1.dev202511150811-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (179)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +105 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +654 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +96 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +182 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +236 -0
  27. tpu_inference/__init__.py +34 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +51 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +728 -0
  37. tpu_inference/distributed/utils.py +59 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +107 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +362 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +390 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +507 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +105 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +311 -0
  120. tpu_inference/mock/__init__.py +0 -0
  121. tpu_inference/mock/vllm_config_utils.py +28 -0
  122. tpu_inference/mock/vllm_envs.py +1219 -0
  123. tpu_inference/mock/vllm_logger.py +212 -0
  124. tpu_inference/mock/vllm_logging_utils.py +15 -0
  125. tpu_inference/models/__init__.py +0 -0
  126. tpu_inference/models/common/__init__.py +0 -0
  127. tpu_inference/models/common/model_loader.py +444 -0
  128. tpu_inference/models/jax/__init__.py +0 -0
  129. tpu_inference/models/jax/deepseek_v3.py +868 -0
  130. tpu_inference/models/jax/gpt_oss.py +492 -0
  131. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  132. tpu_inference/models/jax/llama3.py +375 -0
  133. tpu_inference/models/jax/llama4.py +629 -0
  134. tpu_inference/models/jax/llama_eagle3.py +333 -0
  135. tpu_inference/models/jax/phi3.py +376 -0
  136. tpu_inference/models/jax/qwen2.py +375 -0
  137. tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
  138. tpu_inference/models/jax/qwen3.py +302 -0
  139. tpu_inference/models/jax/utils/__init__.py +0 -0
  140. tpu_inference/models/jax/utils/file_utils.py +96 -0
  141. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  142. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  143. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  144. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  145. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  146. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  147. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  148. tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
  149. tpu_inference/models/jax/utils/weight_utils.py +529 -0
  150. tpu_inference/models/vllm/__init__.py +0 -0
  151. tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
  152. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  153. tpu_inference/platforms/__init__.py +2 -0
  154. tpu_inference/platforms/tpu_platform.py +269 -0
  155. tpu_inference/runner/__init__.py +0 -0
  156. tpu_inference/runner/block_table.py +122 -0
  157. tpu_inference/runner/compilation_manager.py +780 -0
  158. tpu_inference/runner/input_batch.py +435 -0
  159. tpu_inference/runner/kv_cache.py +132 -0
  160. tpu_inference/runner/kv_cache_manager.py +479 -0
  161. tpu_inference/runner/lora_utils.py +92 -0
  162. tpu_inference/runner/multimodal_manager.py +217 -0
  163. tpu_inference/runner/persistent_batch_manager.py +244 -0
  164. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  165. tpu_inference/runner/structured_decoding_manager.py +88 -0
  166. tpu_inference/runner/tpu_runner.py +1620 -0
  167. tpu_inference/runner/utils.py +426 -0
  168. tpu_inference/spec_decode/__init__.py +0 -0
  169. tpu_inference/spec_decode/jax/__init__.py +0 -0
  170. tpu_inference/spec_decode/jax/eagle3.py +367 -0
  171. tpu_inference/tpu_info.py +77 -0
  172. tpu_inference/utils.py +317 -0
  173. tpu_inference/worker/__init__.py +0 -0
  174. tpu_inference/worker/tpu_worker.py +321 -0
  175. tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
  176. tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
  177. tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
  178. tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
  179. tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tpu_inference/models/jax/llama4.py (new file)
@@ -0,0 +1,629 @@
+import re
+from typing import List, Optional, Tuple
+
+import jax
+import jax.numpy as jnp
+from flax import nnx
+from flax.typing import PRNGKey
+from jax.sharding import Mesh
+from jax.sharding import PartitionSpec as P
+from vllm.config import VllmConfig
+
+from tpu_inference.layers.jax.attention.attention import AttentionMetadata
+from tpu_inference.layers.jax.attention.llama4_attention import Llama4Attention
+from tpu_inference.layers.jax.constants import KVCacheType
+from tpu_inference.layers.jax.layers import DenseFFW, Embedder, LMhead, RMSNorm
+from tpu_inference.layers.jax.misc import shard_put
+from tpu_inference.layers.jax.moe.moe import MoE, Router
+from tpu_inference.layers.jax.transformer_block import \
+    SharedExpertsTransformerBlock
+from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.utils.weight_utils import (
+    convert_torch_to_jax_with_view, get_param, model_weights_generator,
+    print_param_info, reshape_params, transpose_params)
+
+logger = init_logger(__name__)
+
+
+class Llama4ForCausalLM(nnx.Module):
+
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 rng: PRNGKey,
+                 mesh: Mesh,
+                 force_random_weights: bool = False):
+        assert mesh is not None
+
+        self.vllm_config = vllm_config
+        model_config = vllm_config.model_config
+        text_config = model_config.hf_config.text_config
+
+        self.rng = nnx.Rngs(rng)
+        self.mesh = mesh
+        self.is_verbose = getattr(self.vllm_config.additional_config,
+                                  "is_verbose", False)
+
+        # Currently the runner will always set a mesh, so the custom default sharding (when
+        # no sharding is set in vllm config) doesn't take effect.
+        # TODO(fhzhang): figure out whether we need to actually enable this.
+        # strategy_dict = {"tensor_parallelism": 4, "expert_parallelism": 2}
+
+        self.vocab_size = model_config.get_vocab_size()
+        self.hidden_size = model_config.get_hidden_size()
+
+        dtype: jnp.dtype = jnp.bfloat16
+
+        self.num_layers: int = getattr(text_config, "num_hidden_layers", 48)
+
+        self.intermediate_size_moe: int = getattr(text_config,
+                                                  "intermediate_size", 8192)
+        self.intermediate_size_mlp = getattr(text_config,
+                                             "intermediate_size_mlp", 16384)
+
+        # num_local_experts: uses 16 experts for Llama-4-Scout-17B-16E-Instruct and 128 experts for Llama-4-Maverick-17B-128E-Instruct.
+        # The default value is set to 16 for compatibility with Llama-4-Scout.
+        self.num_local_experts: int = getattr(text_config, "num_local_experts",
+                                              16)
+        self.hidden_act: str = getattr(text_config, "hidden_act", "silu")
+        self.no_rope_layer_interval = 4
+
+        # interleave_moe_layer_step has a layer step of 2 to interleave MoE and dense layers for Llama-4-Maverick-17B-128E-Instruct.
+        # The default value is set to 1 for compatibility with Llama-4-Scout.
+        self.interleave_moe_layer_step = getattr(text_config,
+                                                 "interleave_moe_layer_step",
+                                                 1)
+
+        self.num_attention_heads = getattr(text_config, "num_attention_heads",
+                                           40)
+        self.num_key_value_heads = getattr(text_config, "num_key_value_heads",
+                                           8)
+        self.head_dim = getattr(text_config, "head_dim", 128)
+
+        self.num_shared_experts = getattr(text_config, "num_experts_per_tok",
+                                          1)
+        self.rms_norm_eps = getattr(text_config, "rms_norm_eps", 1e-5)
+
+        self.rope_scaling = getattr(text_config, "rope_scaling", None)
+        if self.rope_scaling:
+            self.rope_scaling["scale_factor"] = self.rope_scaling.pop("factor")
+
+        self.use_qk_norm = getattr(text_config, "use_qk_norm", True)
+
+        self.embedder = Embedder(vocab_size=self.vocab_size,
+                                 hidden_size=self.hidden_size,
+                                 dtype=dtype,
+                                 vd_sharding=(('data', 'expert', 'model'),
+                                              None),
+                                 rngs=self.rng,
+                                 random_init=force_random_weights)
+
+        self.layers = []
+
+        for i in range(self.num_layers):
+            # For Llama4-Scout, all layers are MoE layers.
+            # This can be adjusted for other variants.
+            is_moe_layer = (i + 1) % \
+                self.interleave_moe_layer_step == 0
+
+            # Llama-4-Scout config: It has "no_rope_layers": []
+            use_attention_rope = (i + 1) % self.no_rope_layer_interval != 0
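
            # Worked example (editorial sketch, not part of the original file): with
            # the Scout defaults interleave_moe_layer_step=1 and no_rope_layer_interval=4,
            # the two checks above give
            #     [((i + 1) % 1 == 0, (i + 1) % 4 != 0) for i in range(8)]
            #     == [(True, True), (True, True), (True, True), (True, False), ...]
            # i.e. every layer is a MoE layer, and every 4th layer uses chunked
            # attention instead of rope. With Maverick's interleave_moe_layer_step=2,
            # only every other layer is a MoE layer.
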
+
+            router = Router(dtype=dtype,
+                            hidden_size=self.hidden_size,
+                            num_experts=self.num_local_experts,
+                            num_experts_per_tok=1,
+                            router_act="sigmoid",
+                            rngs=self.rng,
+                            activation_ffw_td=('data', None),
+                            ed_sharding=(None, None),
+                            random_init=force_random_weights)
+
+            moe_ffw = MoE(
+                dtype=dtype,
+                num_local_experts=self.num_local_experts,
+                apply_expert_weight_before_computation=True,
+                hidden_size=self.hidden_size,
+                intermediate_size_moe=self.intermediate_size_moe,
+                hidden_act=self.hidden_act,
+                router=router,
+                rngs=self.rng,
+                activation_ffw_td=('data', None),
+                activation_ffw_ted=('data', 'expert', None),
+                edf_sharding=('model', None, None),
+                efd_sharding=('model', None, None),
+                random_init=force_random_weights) if is_moe_layer else None
+
+            dense_ffw = DenseFFW(
+                dtype=dtype,
+                hidden_act=self.hidden_act,
+                hidden_size=self.hidden_size,
+                intermediate_size=self.intermediate_size_mlp,
+                random_init=force_random_weights,
+                rngs=self.rng,
+                df_sharding=(None, 'model'),
+                fd_sharding=('model', None),
+                activation_ffw_td=('data', None)) if not is_moe_layer else None
+
+            attn = Llama4Attention(
+                hidden_size=self.hidden_size,
+                dtype=dtype,
+                kv_cache_dtype=vllm_config.cache_config.cache_dtype,
+                num_attention_heads=self.num_attention_heads,
+                num_key_value_heads=self.num_key_value_heads,
+                head_dim=self.head_dim,
+                rope_theta=500000.0,
+                # https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/config.json
+                rope_scaling=self.rope_scaling,
+                rngs=self.rng,
+                rope_input_ordering="interleaved",
+                temperature_tuning=True,
+                temperature_tuning_scale=0.1,
+                temperature_tuning_floor_scale=8192,
+                use_qk_norm=self.use_qk_norm,
+                attention_chunk_size=None if use_attention_rope else 8192,
+                mesh=self.mesh,
+                random_init=force_random_weights,
+                activation_attention_td=('data', 'model'),
+                activation_q_td=('data', 'model'),
+                query_tnh=P('data', 'model', None),
+                keyvalue_skh=P('data', 'model', None),
+                activation_attention_out_td=('data', 'model'),
+                attn_o_tnh=P('data', 'model', None),
+                dnh_sharding=(None, 'model', None),
+                dkh_sharding=(None, 'model', None),
+                nhd_sharding=('model', None, None),
+            )
+
+            shared_experts = DenseFFW(
+                dtype=dtype,
+                hidden_act=self.hidden_act,
+                hidden_size=self.hidden_size,
+                intermediate_size=self.num_shared_experts *
+                self.intermediate_size_moe,
+                rngs=self.rng,
+                random_init=force_random_weights,
+                df_sharding=(None, 'model'),
+                fd_sharding=('model', None),
+                activation_ffw_td=('data', None)) if is_moe_layer else None
+
+            pre_attention_norm = RMSNorm(
+                dims=self.hidden_size,
+                random_init=force_random_weights,
+                epsilon=self.rms_norm_eps,
+                rngs=self.rng,
+                with_scale=True,
+                dtype=dtype,
+                activation_ffw_td=('data', None),
+            )
+
+            pre_mlp_norm = RMSNorm(
+                dims=self.hidden_size,
+                epsilon=self.rms_norm_eps,
+                rngs=self.rng,
+                with_scale=True,
+                dtype=dtype,
+                random_init=force_random_weights,
+                activation_ffw_td=('data', None),
+            )
+
+            block = SharedExpertsTransformerBlock(
+                moe_ffw=moe_ffw if is_moe_layer else None,
+                dense_ffw=dense_ffw if not is_moe_layer else None,
+                shared_experts=shared_experts if is_moe_layer else None,
+                attn=attn,
+                pre_attention_norm=pre_attention_norm,
+                pre_mlp_norm=pre_mlp_norm,
+                use_attention_rope=use_attention_rope)
+            self.layers.append(block)
+
+        self.final_norm = RMSNorm(
+            dims=self.hidden_size,
+            epsilon=self.rms_norm_eps,
+            rngs=self.rng,
+            with_scale=True,
+            dtype=dtype,
+            random_init=force_random_weights,
+        )
+
+        self.lm_head = LMhead(vocab_size=self.vocab_size,
+                              hidden_size=self.hidden_size,
+                              dtype=dtype,
+                              rngs=self.rng,
+                              vd_sharding=(('data', 'expert', 'model'), None),
+                              dv_sharding=(None, ('data', 'expert', 'model')),
+                              random_init=force_random_weights)
+        if self.is_verbose:
+            self._print_model_architecture()
+
+    def _print_model_architecture(self):
+        num_display_layers = max(self.interleave_moe_layer_step,
+                                 self.no_rope_layer_interval)
+
+        logger.info("### Embedding ###")
+        nnx.display(self.embedder)
+
+        logger.info(f"\n### First {num_display_layers} Layers ###")
+        # Loop through the slice and display each layer
+        for i, layer in enumerate(self.layers[:num_display_layers]):
+            logger.info(f"\n--- Layer {i} ---")
+            nnx.display(layer)
+
+        logger.info("\n### LM Head ###")
+        nnx.display(self.lm_head)
+
+    def load_weights(self, rng: jax.Array, cache_dir: Optional[str] = None):
+        # NOTE: Since we are using nnx.eval_shape to init the model,
+        # we have to pass dynamic arrays here for __call__'s usage.
+        self.rng = nnx.Rngs(rng)
+
+        weight_loader = Llama4WeightLoader(
+            vllm_config=self.vllm_config,
+            hidden_size=self.hidden_size,
+            attn_heads=self.num_attention_heads,
+            num_key_value_heads=self.num_key_value_heads,
+            attn_head_dim=self.head_dim)
+        weight_loader.load_weights(self)
+
+    def __call__(
+        self,
+        kv_caches: List[jax.Array],
+        input_ids: jax.Array,
+        attention_metadata: AttentionMetadata,
+        *args,
+    ) -> Tuple[List[KVCacheType], jax.Array, List[jax.Array]]:
+        is_prefill = False
+        x_TD = self.embedder.encode(input_ids)
+
+        for (i, block) in enumerate(self.layers):
+            kv_cache = kv_caches[i]
+            new_kv_cache, x_TD = block(x_TD, is_prefill, kv_cache,
+                                       attention_metadata)
+            jax.block_until_ready(x_TD)
+            kv_caches[i] = new_kv_cache
+
+        final_activation_TD = self.final_norm(x_TD)
+
+        return kv_caches, final_activation_TD, []
+
+    def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
+        logits_TV = jnp.dot(hidden_states,
+                            self.lm_head.input_embedding_table_DV.value)
+        return logits_TV
+
+
+class Llama4WeightLoader:
+
+    def __init__(self, vllm_config: VllmConfig, hidden_size, attn_heads,
+                 num_key_value_heads, attn_head_dim):
+        self.names_and_weights_generator = model_weights_generator(
+            model_name_or_path=vllm_config.model_config.model,
+            framework="pt",
+            filter_regex="language_model",
+            download_dir=vllm_config.load_config.download_dir)
+        self.is_verbose = getattr(vllm_config.additional_config, "is_verbose",
+                                  False)
+        self.interleave_moe_layer_step = getattr(
+            vllm_config.model_config.hf_config.text_config,
+            "interleave_moe_layer_step", 1)
+
+        self.quantization_config = getattr(vllm_config.model_config.hf_config,
+                                           "quantization_config", None)
+        self.expert_weights_buffer = {}
+        self.expert_prefix = "shared_expert."
+
+        transpose_mappings_to_quantization = {
+            "down_proj": (1, 0),
+            "gate_proj": (1, 0),
+            "up_proj": (1, 0),
+        }
+
+        self._transpose_map = {
+            "q_proj": (2, 0, 1),
+            "k_proj": (2, 0, 1),
+            "v_proj": (2, 0, 1),
+            "router": (1, 0),
+            f"{self.expert_prefix}down_proj": (1, 0),
+            f"{self.expert_prefix}gate_proj": (1, 0),
+            f"{self.expert_prefix}up_proj": (1, 0),
+            "feed_forward.down_proj": (1, 0),
+            "feed_forward.gate_proj": (1, 0),
+            "feed_forward.up_proj": (1, 0),
+            "o_proj": (1, 2, 0),
+            "lm_head": (1, 0),
+        }
+
+        if self.quantization_config and self.expert_prefix:
+            self._transpose_map.update(transpose_mappings_to_quantization)
+
+        self._weight_shape_map = {
+            "q_proj": (attn_heads, attn_head_dim, hidden_size),
+            "k_proj": (num_key_value_heads, attn_head_dim, hidden_size),
+            "v_proj": (num_key_value_heads, attn_head_dim, hidden_size),
+            # o_proj is inverted: https://github.com/huggingface/transformers/blob/v4.53.2/src/transformers/models/llama4/modeling_llama4.py#L298
+            "o_proj": (hidden_size, attn_heads, attn_head_dim),
+        }
+
+        # Set the mappings from loaded parameter keys to standardized names.
+        # 1. EXPERT_MAPPINGS_FUSED: Used for non-quantized (e.g., BF16) checkpoints.
+        #    - This format typically comes from standard checkpoints where 'gate' and 'up' projection weights might be combined (FUSED) into a single tensor.
+        #    - Expert weights are usually stacked, with the expert dimension (E) being the first dimension.
+        EXPERT_MAPPINGS_FUSED = {
+            "language_model.model.layers.*.feed_forward.experts.down_proj":
+            "layers.*.moe_ffw.kernel_down_proj_EFD",
+            "language_model.model.layers.*.feed_forward.experts.gate_up_proj":
+            "layers.*.moe_ffw.kernel_up_proj_EDF",
+        }
+
+        # 2. EXPERT_MAPPINGS_UNFUSED: Specifically designed for quantized checkpoints (e.g., FP8).
+        #    - Quantized checkpoints store each expert's weights separately and explicitly separate the 'weight' (quantized value) from the 'weight_scale' (quantization scale).
+        #    - The mapping captures both the `.weight` and `.weight_scale` components. This allows the loader to aggregate (stack) the individual expert weights and scales.
+        EXPERT_MAPPINGS_UNFUSED = {
+            "language_model.model.layers.*.feed_forward.experts.*.down_proj.weight":
+            "layers.*.moe_ffw.kernel_down_proj_EFD",
+            "language_model.model.layers.*.feed_forward.experts.*.down_proj.weight_scale":
+            "layers.*.moe_ffw.kernel_down_proj_EFD",
+            "language_model.model.layers.*.feed_forward.experts.*.gate_proj.weight":
+            "layers.*.moe_ffw.kernel_gating_EDF",
+            "language_model.model.layers.*.feed_forward.experts.*.gate_proj.weight_scale":
+            "layers.*.moe_ffw.kernel_gating_EDF",
+            "language_model.model.layers.*.feed_forward.experts.*.up_proj.weight":
+            "layers.*.moe_ffw.kernel_up_proj_EDF",
+            "language_model.model.layers.*.feed_forward.experts.*.up_proj.weight_scale":
+            "layers.*.moe_ffw.kernel_up_proj_EDF",
+        }
+
+        self._loaded_to_standardized_keys = {
+            "language_model.model.embed_tokens.weight":
+            "embedder.input_embedding_table_VD",
+            "language_model.lm_head.weight":
+            "lm_head.input_embedding_table_DV",
+            "language_model.model.norm.weight":
+            "final_norm.scale",
+            "language_model.model.layers.*.input_layernorm.weight":
+            "layers.*.pre_attention_norm.scale",
+            "language_model.model.layers.*.post_attention_layernorm.weight":
+            "layers.*.pre_mlp_norm.scale",
+            "language_model.model.layers.*.self_attn.q_proj.weight":
+            "layers.*.attn.kernel_q_proj_DNH",
+            "language_model.model.layers.*.self_attn.k_proj.weight":
+            "layers.*.attn.kernel_k_proj_DKH",
+            "language_model.model.layers.*.self_attn.v_proj.weight":
+            "layers.*.attn.kernel_v_proj_DKH",
+            "language_model.model.layers.*.self_attn.o_proj.weight":
+            "layers.*.attn.kernel_o_proj_NHD",
+            "language_model.model.layers.*.feed_forward.router.weight":
+            "layers.*.moe_ffw.router.kernel_DE",
+            # shared experts
+            "language_model.model.layers.*.feed_forward.shared_expert.down_proj.weight":
+            "layers.*.shared_experts.kernel_down_proj_FD",
+            "language_model.model.layers.*.feed_forward.shared_expert.gate_proj.weight":
+            "layers.*.shared_experts.kernel_gating_DF",
+            "language_model.model.layers.*.feed_forward.shared_expert.up_proj.weight":
+            "layers.*.shared_experts.kernel_up_proj_DF",
+            # dense layers
+            "language_model.model.layers.*.feed_forward.down_proj.weight":
+            "layers.*.dense_ffw.kernel_down_proj_FD",
+            "language_model.model.layers.*.feed_forward.up_proj.weight":
+            "layers.*.dense_ffw.kernel_up_proj_DF",
+            "language_model.model.layers.*.feed_forward.gate_proj.weight":
+            "layers.*.dense_ffw.kernel_gating_DF",
+        }
+
+        if self.quantization_config is None:
+            self._loaded_to_standardized_keys.update(EXPERT_MAPPINGS_FUSED)
+        else:
+            self._loaded_to_standardized_keys.update(EXPERT_MAPPINGS_UNFUSED)
+
+    def map_loaded_to_standardized_name(self, loaded_key: str) -> str:
+        # Find the corresponding model key using the HF key
+        if "layer" in loaded_key:
+            layer_num = self._get_layer_num(loaded_key)
+            layer_key = re.sub(r"layers\.\d+", "layers.*", loaded_key)
+
+            expert_match = re.search(r"experts\.(\d+)", layer_key)
+            if expert_match:
+                # Key for lookup eg: layers.*.feed_forward.experts.*.down_proj.weight
+                layer_key = re.sub(r"experts\.\d+", "experts.*", layer_key)
+
+            mapped_key = self._loaded_to_standardized_keys.get(
+                layer_key, loaded_key)
+            mapped_key = re.sub(r"layers\.\*", f"layers.{layer_num}",
+                                mapped_key)
+        else:
+            mapped_key = self._loaded_to_standardized_keys.get(
+                loaded_key, loaded_key)
+        return mapped_key
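
        # Worked example (editorial sketch, not part of the original file): for the key
        #     "language_model.model.layers.7.self_attn.q_proj.weight"
        # the re.sub above first yields the wildcard form
        #     "language_model.model.layers.*.self_attn.q_proj.weight",
        # which the table maps to "layers.*.attn.kernel_q_proj_DNH"; substituting the
        # layer number back in returns "layers.7.attn.kernel_q_proj_DNH".
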
+
+    def _map_llama4_gate_up_proj(self, model_for_loading: nnx.Module,
+                                 model_params: nnx.State, loaded_name: str,
+                                 loaded_weight: jax.Array):
+        """HF's gate_up_proj is a fused tensor of gate and up projections. It needs to be split."""
+
+        cast_type = jnp.dtype(jnp.bfloat16)
+        # loaded_weight is a jax.Array when framework="flax", otherwise it's bfloat16
+        if not isinstance(loaded_weight, jax.Array):
+            loaded_weight = convert_torch_to_jax_with_view(
+                loaded_weight, cast_type)
+
+        split_weights = jnp.split(loaded_weight, 2, axis=-1)
+        layer_num = self._get_layer_num(loaded_name)
+
+        for split_type in ["gate", "up"]:
+            split_loaded_name = loaded_name.replace("gate_up_proj",
+                                                    f"{split_type}_proj")
+            if split_type == "gate":
+                mapped_name = "layers.*.moe_ffw.kernel_gating_EDF"
+                loaded_weight = split_weights[0]
+            else:
+                mapped_name = "layers.*.moe_ffw.kernel_up_proj_EDF"
+                loaded_weight = split_weights[1]
+
+            mapped_name = re.sub(r"layers\.\*", f"layers.{layer_num}",
+                                 mapped_name)
+
+            mapped_model_weight = get_param(model_params, mapped_name)
+
+            if mapped_model_weight.value.shape != loaded_weight.shape:
+                raise ValueError(
+                    f"Loaded shape for {split_loaded_name}: {loaded_weight.shape} "
+                    f"does not match model shape for {mapped_name}: {mapped_model_weight.value.shape}!"
+                )
+
+            mapped_model_weight.value = shard_put(loaded_weight,
+                                                  mapped_model_weight.sharding,
+                                                  mesh=model_for_loading.mesh)
+            logger.debug(
+                f"{split_loaded_name}: {loaded_weight.shape} --> {mapped_name}: {mapped_model_weight.value.shape}"
+            )
+            if self.is_verbose:
+                print_param_info(mapped_model_weight, mapped_name)
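
        # Shape sketch (editorial, not part of the original file): HF typically fuses
        # the per-expert gate and up projections along the last axis, so a gate_up_proj
        # tensor of shape (E, D, 2F) is split by jnp.split(..., 2, axis=-1) above into
        # two (E, D, F) tensors, assigned to kernel_gating_EDF and kernel_up_proj_EDF.
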
+
+    def _get_layer_num(self, loaded_key: str) -> Optional[int]:
+        """
+        Extracts the layer number from a HuggingFace weight key string.
+        Returns the layer number (int) or None if no layer number is found.
+        """
+        match = re.search(r"layers\.(\d+)", loaded_key)
+        if match:
+            return int(match.group(1))
+        return None
+
+    def _get_expert_num(self, loaded_key: str) -> Optional[int]:
+        """
+        Extracts the expert number from a HuggingFace weight key string.
+        Returns the expert number (int) or None if no expert number is found.
+        """
+        match = re.search(r"experts\.(\d+)\.", loaded_key)
+        if match:
+            return int(match.group(1))
+        return None
+
+    def load_weights(self, model_for_loading: nnx.Module):
+        model_params = nnx.state(model_for_loading)
+
+        with jax.default_device(jax.devices("cpu")[0]):
+            for loaded_name, loaded_weight in self.names_and_weights_generator:
+                is_moe_layer = False
+                layer_num = self._get_layer_num(loaded_name)
+                expert_num = self._get_expert_num(loaded_name)
+                # Quantized (FP8) checkpoints unstack the expert weights, while unquantized (BF16) checkpoints keep them stacked.
+                is_unfused_expert = self.quantization_config is not None and expert_num is not None
+                is_scale = loaded_name.endswith(".weight_scale")
+
+                if is_unfused_expert:
+                    mapped_name = self.map_loaded_to_standardized_name(
+                        loaded_name)
+                    model_weight = get_param(model_params, mapped_name)
+
+                    if is_scale:
+                        cast_type = model_weight.array.scale.value.dtype
+                    else:
+                        cast_type = model_weight.array.qvalue.value.dtype
+
+                    loaded_weight = convert_torch_to_jax_with_view(
+                        loaded_weight, cast_type)
+                    loaded_weight = transpose_params(loaded_name,
+                                                     loaded_weight,
+                                                     self._transpose_map)
+
+                    buffer_key = f"{mapped_name}_{'scale' if is_scale else 'qvalue'}"
+                    if buffer_key not in self.expert_weights_buffer:
+                        self.expert_weights_buffer[buffer_key] = {}
+                    self.expert_weights_buffer[buffer_key][
+                        expert_num] = loaded_weight
+                    continue
+
+                if layer_num is not None:
+                    is_moe_layer = (layer_num + 1) % \
+                        self.interleave_moe_layer_step == 0
+                    self.expert_prefix = "shared_expert." if is_moe_layer else ""
+
+                if "gate_up_proj" in loaded_name:
+                    self._map_llama4_gate_up_proj(model_for_loading,
+                                                  model_params, loaded_name,
+                                                  loaded_weight)
+                    continue
+
+                mapped_name = self.map_loaded_to_standardized_name(loaded_name)
+                model_weight = get_param(model_params, mapped_name)
+
+                cast_type = model_weight.value.dtype
+                if not isinstance(loaded_weight, jax.Array):
+                    logger.debug(
+                        f"Converting PyTorch tensor {loaded_name} to JAX {cast_type}"
+                    )
+                    loaded_weight = convert_torch_to_jax_with_view(
+                        loaded_weight, cast_type)
+
+                if not loaded_name.endswith(".bias"):
+                    loaded_weight = reshape_params(loaded_name, loaded_weight,
+                                                   self._weight_shape_map)
+                    loaded_weight = transpose_params(loaded_name,
+                                                     loaded_weight,
+                                                     self._transpose_map)
+                if model_weight.value.shape != loaded_weight.shape:
+                    raise ValueError(
+                        f"Loaded shape for {loaded_name}: {loaded_weight.shape} "
+                        f"does not match model shape for {mapped_name}: {model_weight.value.shape}!"
+                    )
+                logger.debug(
+                    f"Transformed parameter {loaded_name} to {mapped_name}: {loaded_weight.shape} --> {model_weight.value.shape}"
+                )
+
+                model_weight.value = shard_put(loaded_weight,
+                                               model_weight.sharding,
+                                               mesh=model_for_loading.mesh)
+                if self.is_verbose:
+                    print_param_info(model_weight, loaded_name)
+
+        with jax.default_device(jax.devices("cpu")[0]):
+            for buffer_key, expert_map in self.expert_weights_buffer.items(
+            ):
+                sorted_exp_nums = sorted(expert_map.keys())
+                aggregated_weight = jnp.stack(
+                    [expert_map[k] for k in sorted_exp_nums], axis=0)
+                is_scale = buffer_key.endswith("_scale")
+                base_mapped_name = buffer_key.replace("_scale",
+                                                      "").replace(
+                                                          "_qvalue", "")
+
+                model_weight = get_param(model_params, base_mapped_name)
+
+                assert hasattr(
+                    model_weight, 'array'
+                ), f"Expected MoE weight '{base_mapped_name}' to be a quantized array (qarray)"
+
+                if is_scale:
+                    loaded_name = f"{base_mapped_name}.array.scale.value"
+                    if model_weight.array.scale.value.shape != aggregated_weight.shape:
+                        raise ValueError(
+                            f"[AGGREGATED] Loaded shape for {buffer_key}: {aggregated_weight.shape}"
+                            f"does not match model shape for {loaded_name}: {model_weight.array.scale.value.shape}!"
+                        )
+
+                    model_weight.array.scale.value = shard_put(
+                        aggregated_weight,
+                        model_weight.array.scale.sharding,
+                        mesh=model_for_loading.mesh)
+
+                elif aggregated_weight.itemsize < 2:  # check model weight elem nbits < 16
+                    loaded_name = f"{base_mapped_name}.array.qvalue.value"
+                    if model_weight.array.qvalue.value.shape != aggregated_weight.shape:
+                        raise ValueError(
+                            f"[AGGREGATED] Loaded shape for {buffer_key}: {aggregated_weight.shape}"
+                            f"does not match model shape for {loaded_name}: {model_weight.array.qvalue.value.shape}!"
+                        )
+
+                    model_weight.array.qvalue.value = shard_put(
+                        aggregated_weight,
+                        model_weight.array.qvalue.sharding,
+                        mesh=model_for_loading.mesh)
+
+                logger.debug(
+                    f"Aggregated and loaded {loaded_name}: {aggregated_weight.shape}"
+                )
+
+                if self.is_verbose:
+                    print_param_info(model_weight, loaded_name)
+
+        nnx.update(model_for_loading, model_params)
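
For reference, a minimal runnable sketch (not part of the package) of the expert-aggregation step in Llama4WeightLoader.load_weights above: per-expert tensors from an unfused (FP8) checkpoint are buffered by expert index and later stacked along a new leading expert axis, matching the E-first layout of the kernel_*_EFD/EDF parameters. The sizes below are toy values, not the real model's.

    import jax.numpy as jnp

    num_experts, F, D = 4, 8, 6  # toy sizes
    # expert_num -> already-transposed (F, D) slice, as buffered in expert_weights_buffer
    buffered = {e: jnp.full((F, D), e, dtype=jnp.bfloat16) for e in range(num_experts)}
    stacked_EFD = jnp.stack([buffered[e] for e in sorted(buffered)], axis=0)
    assert stacked_EFD.shape == (num_experts, F, D)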