tpu-inference 0.11.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This release of tpu-inference has been flagged as potentially problematic.

Files changed (168)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_adapters.py +83 -0
  4. tests/core/test_core_tpu.py +523 -0
  5. tests/core/test_disagg_executor.py +60 -0
  6. tests/core/test_disagg_utils.py +53 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  10. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  11. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  12. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  13. tests/lora/__init__.py +0 -0
  14. tests/lora/test_lora.py +123 -0
  15. tests/test_base.py +201 -0
  16. tests/test_quantization.py +836 -0
  17. tests/test_tpu_info.py +120 -0
  18. tests/test_utils.py +218 -0
  19. tests/tpu_backend_test.py +59 -0
  20. tpu_inference/__init__.py +30 -0
  21. tpu_inference/adapters/__init__.py +0 -0
  22. tpu_inference/adapters/vllm_adapters.py +42 -0
  23. tpu_inference/adapters/vllm_config_adapters.py +134 -0
  24. tpu_inference/backend.py +69 -0
  25. tpu_inference/core/__init__.py +0 -0
  26. tpu_inference/core/adapters.py +153 -0
  27. tpu_inference/core/core_tpu.py +776 -0
  28. tpu_inference/core/disagg_executor.py +117 -0
  29. tpu_inference/core/disagg_utils.py +51 -0
  30. tpu_inference/di/__init__.py +0 -0
  31. tpu_inference/di/abstracts.py +28 -0
  32. tpu_inference/di/host.py +76 -0
  33. tpu_inference/di/interfaces.py +51 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/tpu_connector.py +699 -0
  36. tpu_inference/distributed/utils.py +59 -0
  37. tpu_inference/executors/__init__.py +0 -0
  38. tpu_inference/executors/ray_distributed_executor.py +346 -0
  39. tpu_inference/experimental/__init__.py +0 -0
  40. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  41. tpu_inference/interfaces/__init__.py +0 -0
  42. tpu_inference/interfaces/cache.py +31 -0
  43. tpu_inference/interfaces/config.py +47 -0
  44. tpu_inference/interfaces/config_parts.py +117 -0
  45. tpu_inference/interfaces/engine.py +51 -0
  46. tpu_inference/interfaces/outputs.py +22 -0
  47. tpu_inference/interfaces/params.py +21 -0
  48. tpu_inference/interfaces/platform.py +74 -0
  49. tpu_inference/interfaces/request.py +39 -0
  50. tpu_inference/interfaces/scheduler.py +31 -0
  51. tpu_inference/kernels/__init__.py +0 -0
  52. tpu_inference/kernels/collectives/__init__.py +0 -0
  53. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  54. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  55. tpu_inference/kernels/collectives/util.py +47 -0
  56. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  57. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  58. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  59. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  60. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  61. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  62. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  66. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
  71. tpu_inference/layers/__init__.py +0 -0
  72. tpu_inference/layers/common/__init__.py +0 -0
  73. tpu_inference/layers/common/attention_metadata.py +34 -0
  74. tpu_inference/layers/jax/__init__.py +0 -0
  75. tpu_inference/layers/jax/attention/__init__.py +0 -0
  76. tpu_inference/layers/jax/attention/attention.py +254 -0
  77. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  78. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  79. tpu_inference/layers/jax/attention_interface.py +356 -0
  80. tpu_inference/layers/jax/base.py +151 -0
  81. tpu_inference/layers/jax/binary_search.py +295 -0
  82. tpu_inference/layers/jax/constants.py +88 -0
  83. tpu_inference/layers/jax/layers.py +301 -0
  84. tpu_inference/layers/jax/misc.py +16 -0
  85. tpu_inference/layers/jax/moe/__init__.py +0 -0
  86. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  87. tpu_inference/layers/jax/moe/moe.py +209 -0
  88. tpu_inference/layers/jax/rope.py +172 -0
  89. tpu_inference/layers/jax/rope_interface.py +214 -0
  90. tpu_inference/layers/jax/sample/__init__.py +0 -0
  91. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  92. tpu_inference/layers/jax/sample/sampling.py +95 -0
  93. tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
  94. tpu_inference/layers/jax/sharding.py +406 -0
  95. tpu_inference/layers/jax/transformer_block.py +76 -0
  96. tpu_inference/layers/vllm/__init__.py +0 -0
  97. tpu_inference/layers/vllm/attention.py +184 -0
  98. tpu_inference/layers/vllm/fused_moe.py +399 -0
  99. tpu_inference/layers/vllm/linear_common.py +186 -0
  100. tpu_inference/layers/vllm/quantization/__init__.py +34 -0
  101. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  102. tpu_inference/layers/vllm/quantization/common.py +105 -0
  103. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  104. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
  105. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  106. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  108. tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
  109. tpu_inference/layers/vllm/sharding.py +151 -0
  110. tpu_inference/logger.py +10 -0
  111. tpu_inference/lora/__init__.py +0 -0
  112. tpu_inference/lora/torch_lora_ops.py +103 -0
  113. tpu_inference/lora/torch_punica_tpu.py +308 -0
  114. tpu_inference/mock/__init__.py +0 -0
  115. tpu_inference/mock/vllm_config_utils.py +28 -0
  116. tpu_inference/mock/vllm_envs.py +1233 -0
  117. tpu_inference/mock/vllm_logger.py +212 -0
  118. tpu_inference/mock/vllm_logging_utils.py +15 -0
  119. tpu_inference/models/__init__.py +0 -0
  120. tpu_inference/models/common/__init__.py +0 -0
  121. tpu_inference/models/common/model_loader.py +433 -0
  122. tpu_inference/models/jax/__init__.py +0 -0
  123. tpu_inference/models/jax/deepseek_v3.py +868 -0
  124. tpu_inference/models/jax/llama3.py +366 -0
  125. tpu_inference/models/jax/llama4.py +473 -0
  126. tpu_inference/models/jax/llama_eagle3.py +333 -0
  127. tpu_inference/models/jax/phi3.py +376 -0
  128. tpu_inference/models/jax/qwen2.py +375 -0
  129. tpu_inference/models/jax/qwen2_5_vl.py +976 -0
  130. tpu_inference/models/jax/qwen3.py +302 -0
  131. tpu_inference/models/jax/utils/__init__.py +0 -0
  132. tpu_inference/models/jax/utils/file_utils.py +96 -0
  133. tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
  134. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
  136. tpu_inference/models/jax/utils/weight_utils.py +510 -0
  137. tpu_inference/models/vllm/__init__.py +0 -0
  138. tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
  139. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  140. tpu_inference/platforms/__init__.py +2 -0
  141. tpu_inference/platforms/tpu_jax.py +257 -0
  142. tpu_inference/runner/__init__.py +0 -0
  143. tpu_inference/runner/block_table_jax.py +122 -0
  144. tpu_inference/runner/compilation_manager.py +672 -0
  145. tpu_inference/runner/input_batch_jax.py +435 -0
  146. tpu_inference/runner/kv_cache.py +119 -0
  147. tpu_inference/runner/kv_cache_manager.py +460 -0
  148. tpu_inference/runner/lora_utils.py +92 -0
  149. tpu_inference/runner/multimodal_manager.py +208 -0
  150. tpu_inference/runner/persistent_batch_manager.py +244 -0
  151. tpu_inference/runner/speculative_decoding_manager.py +250 -0
  152. tpu_inference/runner/structured_decoding_manager.py +89 -0
  153. tpu_inference/runner/tpu_jax_runner.py +771 -0
  154. tpu_inference/runner/utils.py +426 -0
  155. tpu_inference/spec_decode/__init__.py +0 -0
  156. tpu_inference/spec_decode/jax/__init__.py +0 -0
  157. tpu_inference/spec_decode/jax/eagle3.py +334 -0
  158. tpu_inference/tpu_info.py +77 -0
  159. tpu_inference/utils.py +294 -0
  160. tpu_inference/worker/__init__.py +0 -0
  161. tpu_inference/worker/_temporary_vllm_compat.py +129 -0
  162. tpu_inference/worker/base.py +100 -0
  163. tpu_inference/worker/tpu_worker_jax.py +321 -0
  164. tpu_inference-0.11.1.dist-info/METADATA +101 -0
  165. tpu_inference-0.11.1.dist-info/RECORD +168 -0
  166. tpu_inference-0.11.1.dist-info/WHEEL +5 -0
  167. tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
  168. tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
tpu_inference/models/jax/llama4.py (new file)
@@ -0,0 +1,473 @@
+import re
+from typing import List, Optional, Tuple
+
+import jax
+import jax.numpy as jnp
+from flax import nnx
+from flax.typing import PRNGKey
+from jax.sharding import Mesh
+from jax.sharding import PartitionSpec as P
+from vllm.config import VllmConfig
+
+from tpu_inference.layers.jax.attention.attention import AttentionMetadata
+from tpu_inference.layers.jax.attention.llama4_attention import Llama4Attention
+from tpu_inference.layers.jax.constants import KVCacheType
+from tpu_inference.layers.jax.layers import DenseFFW, Embedder, LMhead, RMSNorm
+from tpu_inference.layers.jax.misc import shard_put
+from tpu_inference.layers.jax.moe.moe import MoE, Router
+from tpu_inference.layers.jax.transformer_block import \
+    SharedExpertsTransformerBlock
+from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.utils.weight_utils import (
+    get_param, model_weights_generator, print_param_info, reshape_params,
+    transpose_params)
+
+logger = init_logger(__name__)
+
+
+class Llama4ForCausalLM(nnx.Module):
+
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 rng: PRNGKey,
+                 mesh: Mesh,
+                 force_random_weights: bool = False):
+        assert mesh is not None
+
+        self.vllm_config = vllm_config
+        model_config = vllm_config.model_config
+        text_config = model_config.hf_config.text_config
+
+        self.rng = nnx.Rngs(rng)
+        self.mesh = mesh
+        self.is_verbose = getattr(self.vllm_config.additional_config,
+                                  "is_verbose", False)
+
+        # Currently the runner will always set a mesh, so the custom default sharding (when
+        # no sharding is set in vllm config) doesn't take effect.
+        # TODO(fhzhang): figure out whether we need to actually enable this.
+        # strategy_dict = {"tensor_parallelism": 4, "expert_parallelism": 2}
+
+        # TODO(fhzhang): remove these once we confirm that the values we get from config are good.
+        # self.hidden_size: int = 5120
+        # vocab_size = 202048
+        self.vocab_size = model_config.get_vocab_size()
+        self.hidden_size = model_config.get_hidden_size()
+
+        dtype: jnp.dtype = jnp.bfloat16
+
+        self.num_layers: int = getattr(text_config, "num_hidden_layers", 48)
+
+        self.intermediate_size_moe: int = getattr(text_config,
+                                                  "intermediate_size", 8192)
+        self.intermediate_size_mlp = getattr(text_config,
+                                             "intermediate_size_mlp", 16384)
+
+        # num_local_experts: 16 experts for Llama-4-Scout-17B-16E-Instruct and 128 experts for Llama-4-Maverick-17B-128E-Instruct.
+        # The default value is set to 16 for compatibility with Llama-4-Scout.
+        self.num_local_experts: int = getattr(text_config, "num_local_experts",
+                                              16)
+        self.hidden_act: str = getattr(text_config, "hidden_act", "silu")
+        self.no_rope_layer_interval = getattr(text_config, "no_rope_layers",
+                                              [])
+
+        # interleave_moe_layer_step has a layer step of 2 to interleave MoE and dense layers for Llama-4-Maverick-17B-128E-Instruct.
+        # The default value is set to 1 for compatibility with Llama-4-Scout.
+        self.interleave_moe_layer_step = getattr(text_config,
+                                                 "interleave_moe_layer_step",
+                                                 1)
+
+        self.num_attention_heads = getattr(text_config, "num_attention_heads",
+                                           40)
+        self.num_key_value_heads = getattr(text_config, "num_key_value_heads",
+                                           8)
+        self.head_dim = getattr(text_config, "head_dim", 128)
+
+        self.num_shared_experts = getattr(text_config, "num_experts_per_tok",
+                                          1)
+        self.rms_norm_eps = getattr(text_config, "rms_norm_eps", 1e-5)
+
+        self.embedder = Embedder(vocab_size=self.vocab_size,
+                                 hidden_size=self.hidden_size,
+                                 dtype=dtype,
+                                 vd_sharding=(('data', 'expert', 'model'),
+                                              None),
+                                 rngs=self.rng,
+                                 random_init=force_random_weights)
+
+        self.layers = []
+
+        for i in range(self.num_layers):
+            # For Llama4-Scout, all layers are MoE layers.
+            # This can be adjusted for other variants.
+            is_moe_layer = (i + 1) % \
+                self.interleave_moe_layer_step == 0
+
+            # Llama-4-Scout config: It has "no_rope_layers": []
+            use_attention_rope = (i + 1) not in self.no_rope_layer_interval
+
+            router = Router(dtype=dtype,
+                            hidden_size=self.hidden_size,
+                            num_experts=self.num_local_experts,
+                            num_experts_per_tok=1,
+                            router_act="sigmoid",
+                            rngs=self.rng,
+                            activation_ffw_td=('data', None),
+                            ed_sharding=(None, 'expert'),
+                            random_init=force_random_weights)
+
+            custom_module = MoE(
+                dtype=dtype,
+                num_local_experts=self.num_local_experts,
+                apply_expert_weight_before_computation=True,
+                hidden_size=self.hidden_size,
+                intermediate_size_moe=self.intermediate_size_moe,
+                hidden_act=self.hidden_act,
+                router=router,
+                rngs=self.rng,
+                activation_ffw_td=('data', None),
+                activation_ffw_ted=('data', 'expert', None),
+                edf_sharding=('expert', None, 'model'),
+                efd_sharding=('expert', 'model', None),
+                random_init=force_random_weights
+            ) if is_moe_layer else DenseFFW(
+                dtype=dtype,
+                hidden_act=self.hidden_act,
+                hidden_size=self.hidden_size,
+                intermediate_size=self.intermediate_size_mlp,
+                random_init=force_random_weights,
+                rngs=self.rng,
+                df_sharding=(None, 'model'),
+                fd_sharding=('model', None),
+                activation_ffw_td=('data', None))
+
+            attn = Llama4Attention(
+                hidden_size=self.hidden_size,
+                dtype=dtype,
+                kv_cache_dtype=vllm_config.cache_config.cache_dtype,
+                num_attention_heads=self.num_attention_heads,
+                num_key_value_heads=self.num_key_value_heads,
+                head_dim=self.head_dim,
+                rope_theta=500000.0,
+                # https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/config.json
+                rope_scaling={
+                    "scale_factor": 16.0,
+                    "low_freq_factor": 1.0,
+                    "high_freq_factor": 1.0,
+                    "original_max_position_embeddings": 8192
+                },
+                rngs=self.rng,
+                rope_input_ordering="interleaved",
+                temperature_tuning=True,
+                temperature_tuning_scale=0.1,
+                temperature_tuning_floor_scale=8192,
+                use_qk_norm=True,
+                attention_chunk_size=None if use_attention_rope else 8192,
+                mesh=self.mesh,
+                random_init=force_random_weights,
+                activation_attention_td=('data', 'model'),
+                activation_q_td=('data', 'model'),
+                query_tnh=P('data', 'model', None),
+                keyvalue_skh=P('data', 'model', None),
+                activation_attention_out_td=('data', 'model'),
+                attn_o_tnh=P('data', 'model', None),
+                dnh_sharding=(None, 'model', None),
+                dkh_sharding=(None, 'model', None),
+                nhd_sharding=('model', None, None),
+            )
+
+            shared_experts = DenseFFW(
+                dtype=dtype,
+                hidden_act=self.hidden_act,
+                hidden_size=self.hidden_size,
+                intermediate_size=self.num_shared_experts *
+                self.intermediate_size_moe,
+                rngs=self.rng,
+                random_init=force_random_weights,
+                df_sharding=(None, 'model'),
+                fd_sharding=('model', None),
+                activation_ffw_td=('data', None)) if is_moe_layer else None
+
+            pre_attention_norm = RMSNorm(
+                dims=self.hidden_size,
+                random_init=force_random_weights,
+                epsilon=self.rms_norm_eps,
+                rngs=self.rng,
+                with_scale=True,
+                dtype=dtype,
+            )
+
+            pre_mlp_norm = RMSNorm(
+                dims=self.hidden_size,
+                epsilon=self.rms_norm_eps,
+                rngs=self.rng,
+                with_scale=True,
+                dtype=dtype,
+                random_init=force_random_weights,
+            )
+
+            block = SharedExpertsTransformerBlock(
+                custom_module=custom_module,
+                attn=attn,
+                pre_attention_norm=pre_attention_norm,
+                pre_mlp_norm=pre_mlp_norm,
+                shared_experts=shared_experts,
+                use_attention_rope=use_attention_rope)
+            self.layers.append(block)
+
+        self.final_norm = RMSNorm(
+            dims=self.hidden_size,
+            epsilon=self.rms_norm_eps,
+            rngs=self.rng,
+            with_scale=True,
+            dtype=dtype,
+            random_init=force_random_weights,
+        )
+
+        self.lm_head = LMhead(vocab_size=self.vocab_size,
+                              hidden_size=self.hidden_size,
+                              dtype=dtype,
+                              rngs=self.rng,
+                              vd_sharding=(('data', 'expert', 'model'), None),
+                              dv_sharding=(None, ('data', 'expert', 'model')),
+                              random_init=force_random_weights)
+        if self.is_verbose:
+            self._print_model_architecture()
+
+    def _print_model_architecture(self):
+        num_display_layers = max(self.interleave_moe_layer_step,
+                                 self.no_rope_layer_interval)
+
+        logger.info("### Embedding ###")
+        nnx.display(self.embedder)
+
+        logger.info(f"\n### First {num_display_layers} Layers ###")
+        # Loop through the slice and display each layer
+        for i, layer in enumerate(self.layers[:num_display_layers]):
+            logger.info(f"\n--- Layer {i} ---")
+            nnx.display(layer)
+
+        logger.info("\n### LM Head ###")
+        nnx.display(self.lm_head)
+
+    def load_weights(self, rng: jax.Array, cache_dir: Optional[str] = None):
+        # NOTE: Since we are using nnx.eval_shape to init the model,
+        # we have to pass dynamic arrays here for __call__'s usage.
+        self.rng = nnx.Rngs(rng)
+
+        weight_loader = Llama4WeightLoader(
+            vllm_config=self.vllm_config,
+            hidden_size=self.hidden_size,
+            attn_heads=self.num_attention_heads,
+            num_key_value_heads=self.num_key_value_heads,
+            attn_head_dim=self.head_dim)
+        weight_loader.load_weights(self)
+
+    def __call__(
+        self,
+        kv_caches: List[jax.Array],
+        input_ids: jax.Array,
+        attention_metadata: AttentionMetadata,
+        *args,
+    ) -> Tuple[List[KVCacheType], jax.Array, List[jax.Array]]:
+        is_prefill = False
+        x_TD = self.embedder.encode(input_ids)
+
+        for (i, block) in enumerate(self.layers):
+            kv_cache = kv_caches[i]
+            new_kv_cache, x_TD = block(x_TD, is_prefill, kv_cache,
+                                       attention_metadata)
+            jax.block_until_ready(x_TD)
+            kv_caches[i] = new_kv_cache
+
+        final_activation_TD = self.final_norm(x_TD)
+
+        return kv_caches, final_activation_TD, []
+
+    def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
+        logits_TV = jnp.dot(hidden_states,
+                            self.lm_head.input_embedding_table_DV.value)
+        return logits_TV
+
+
+class Llama4WeightLoader:
+
+    def __init__(self, vllm_config: VllmConfig, hidden_size, attn_heads,
+                 num_key_value_heads, attn_head_dim):
+        self.names_and_weights_generator = model_weights_generator(
+            model_name_or_path=vllm_config.model_config.model,
+            framework="flax",
+            filter_regex="language_model",
+            download_dir=vllm_config.load_config.download_dir)
+        self.is_verbose = getattr(vllm_config.additional_config, "is_verbose",
+                                  False)
+        self.interleave_moe_layer_step = getattr(
+            vllm_config.model_config.hf_config.text_config,
+            "interleave_moe_layer_step", 1)
+
+        self.expert_prefix = "shared_expert."
+        self._transpose_map = {
+            "q_proj": (2, 0, 1),
+            "k_proj": (2, 0, 1),
+            "v_proj": (2, 0, 1),
+            "router": (1, 0),
+            f"{self.expert_prefix}down_proj": (1, 0),
+            f"{self.expert_prefix}gate_proj": (1, 0),
+            f"{self.expert_prefix}up_proj": (1, 0),
+            "feed_forward.down_proj": (1, 0),
+            "feed_forward.gate_proj": (1, 0),
+            "feed_forward.up_proj": (1, 0),
+            "o_proj": (1, 2, 0),
+            "lm_head": (1, 0),
+        }
+
+        self._weight_shape_map = {
+            "q_proj": (attn_heads, attn_head_dim, hidden_size),
+            "k_proj": (num_key_value_heads, attn_head_dim, hidden_size),
+            "v_proj": (num_key_value_heads, attn_head_dim, hidden_size),
+            # o_proj is inverted: https://github.com/huggingface/transformers/blob/v4.53.2/src/transformers/models/llama4/modeling_llama4.py#L298
+            "o_proj": (hidden_size, attn_heads, attn_head_dim),
+        }
+
+        # Set the mappings from loaded parameter keys to standardized names.
+        self._loaded_to_standardized_keys = {
+            "language_model.model.embed_tokens.weight":
+            "embedder.input_embedding_table_VD",
+            "language_model.lm_head.weight":
+            "lm_head.input_embedding_table_DV",
+            "language_model.model.norm.weight":
+            "final_norm.scale",
+            "language_model.model.layers.*.input_layernorm.weight":
+            "layers.*.pre_attention_norm.scale",
+            "language_model.model.layers.*.post_attention_layernorm.weight":
+            "layers.*.pre_mlp_norm.scale",
+            "language_model.model.layers.*.self_attn.q_proj.weight":
+            "layers.*.attn.kernel_q_proj_DNH",
+            "language_model.model.layers.*.self_attn.k_proj.weight":
+            "layers.*.attn.kernel_k_proj_DKH",
+            "language_model.model.layers.*.self_attn.v_proj.weight":
+            "layers.*.attn.kernel_v_proj_DKH",
+            "language_model.model.layers.*.self_attn.o_proj.weight":
+            "layers.*.attn.kernel_o_proj_NHD",
+            "language_model.model.layers.*.feed_forward.router.weight":
+            "layers.*.custom_module.router.kernel_DE",
+            "language_model.model.layers.*.feed_forward.experts.down_proj":
+            "layers.*.custom_module.kernel_down_proj_EFD",
+            "language_model.model.layers.*.feed_forward.experts.gate_up_proj":
+            "layers.*.custom_module.kernel_up_proj_EDF",
+            "language_model.model.layers.*.feed_forward.shared_expert.down_proj.weight":
+            "layers.*.shared_experts.kernel_down_proj_FD",
+            "language_model.model.layers.*.feed_forward.shared_expert.gate_proj.weight":
+            "layers.*.shared_experts.kernel_gating_DF",
+            "language_model.model.layers.*.feed_forward.shared_expert.up_proj.weight":
+            "layers.*.shared_experts.kernel_up_proj_DF",
+            "language_model.model.layers.*.feed_forward.down_proj.weight":
+            "layers.*.custom_module.kernel_down_proj_FD",
+            "language_model.model.layers.*.feed_forward.up_proj.weight":
+            "layers.*.custom_module.kernel_up_proj_DF",
+            "language_model.model.layers.*.feed_forward.gate_proj.weight":
+            "layers.*.custom_module.kernel_gating_DF",
+        }
+
+    def map_loaded_to_standardized_name(self, loaded_key: str) -> str:
+        # Find the corresponding model key using the HF key
+        if "layer" in loaded_key:
+            layer_num = re.search(r"layers\.(\d+)", loaded_key).group(1)
+            layer_key = re.sub(r"layers\.\d+", "layers.*", loaded_key)
+            mapped_key = self._loaded_to_standardized_keys.get(
+                layer_key, loaded_key)
+            mapped_key = re.sub(r"layers\.\*", f"layers.{layer_num}",
+                                mapped_key)
+        else:
+            mapped_key = self._loaded_to_standardized_keys.get(
+                loaded_key, loaded_key)
+        return mapped_key
+
+    def _map_llama4_gate_up_proj(self, model_for_loading: nnx.Module,
+                                 model_params: nnx.State, loaded_name: str,
+                                 loaded_weight: jax.Array):
+        """HF's gate_up_proj is a fused tensor of gate and up projections. It needs to be split."""
+        # gate_proj is first & up_proj is second
+        split_weights = jnp.split(loaded_weight, 2, axis=-1)
+
+        for split_type in ["gate", "up"]:
+            split_loaded_name = loaded_name.replace("gate_up_proj",
+                                                    f"{split_type}_proj")
+            if split_type == "gate":
+                mapped_name = "layers.*.custom_module.kernel_gating_EDF"
+                loaded_weight = split_weights[0]
+            else:
+                mapped_name = "layers.*.custom_module.kernel_up_proj_EDF"
+                loaded_weight = split_weights[1]
+
+            layer_num = re.search(r"layers\.(\d+)", split_loaded_name).group(1)
+            mapped_name = re.sub(r"layers\.\*", f"layers.{layer_num}",
+                                 mapped_name)
+            mapped_model_weight = get_param(model_params, mapped_name)
+
+            if mapped_model_weight.value.shape != loaded_weight.shape:
+                raise ValueError(
+                    f"Loaded shape for {split_loaded_name}: {loaded_weight.shape} "
+                    f"does not match model shape for {mapped_name}: {mapped_model_weight.value.shape}!"
+                )
+            mapped_model_weight.value = shard_put(loaded_weight,
+                                                  mapped_model_weight.sharding,
+                                                  mesh=model_for_loading.mesh)
+            logger.debug(
+                f"{split_loaded_name}: {loaded_weight.shape} --> {mapped_name}: {mapped_model_weight.value.shape}"
+            )
+            if self.is_verbose:
+                print_param_info(mapped_model_weight, mapped_name)
+
+    def _get_layer_num(self, loaded_key: str) -> Optional[int]:
+        """
+        Extracts the layer number from a HuggingFace weight key string.
+        Returns the layer number (int) or None if no layer number is found.
+        """
+        match = re.search(r"layers\.(\d+)", loaded_key)
+        if match:
+            return int(match.group(1))
+        return None
+
+    def load_weights(self, model_for_loading: nnx.Module):
+        model_params = nnx.state(model_for_loading)
+
+        with jax.default_device(jax.devices("cpu")[0]):
+            for loaded_name, loaded_weight in self.names_and_weights_generator:
+                is_moe_layer = False
+                layer_num = self._get_layer_num(loaded_name)
+
+                if layer_num is not None:
+                    is_moe_layer = (layer_num + 1) % \
+                        self.interleave_moe_layer_step == 0
+                    self.expert_prefix = "shared_expert." if is_moe_layer else ""
+
+                if "gate_up_proj" in loaded_name:
+                    self._map_llama4_gate_up_proj(model_for_loading,
+                                                  model_params, loaded_name,
+                                                  loaded_weight)
+                    continue
+                mapped_name = self.map_loaded_to_standardized_name(loaded_name)
+                model_weight = get_param(model_params, mapped_name)
+
+                if not loaded_name.endswith(".bias"):
+                    loaded_weight = reshape_params(loaded_name, loaded_weight,
+                                                   self._weight_shape_map)
+                    loaded_weight = transpose_params(loaded_name,
+                                                     loaded_weight,
+                                                     self._transpose_map)
+                if model_weight.value.shape != loaded_weight.shape:
+                    raise ValueError(
+                        f"Loaded shape for {loaded_name}: {loaded_weight.shape} "
+                        f"does not match model shape for {mapped_name}: {model_weight.value.shape}!"
+                    )
+                logger.debug(
+                    f"Transformed parameter {loaded_name} to {mapped_name}: {loaded_weight.shape} --> {model_weight.value.shape}"
+                )
+                model_weight.value = shard_put(loaded_weight,
+                                               model_weight.sharding,
+                                               mesh=model_for_loading.mesh)
+                if self.is_verbose:
+                    print_param_info(model_weight, loaded_name)
+
+        nnx.update(model_for_loading, model_params)
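
For orientation, a minimal sketch of how the Llama4ForCausalLM entry point added in this file might be exercised. The mesh axis names ('data', 'expert', 'model') follow the shardings used in the constructor; vllm_config is assumed to be an already-populated vllm.config.VllmConfig for a Llama-4 checkpoint, so this is an illustrative sketch rather than part of the package's documented API.

    import numpy as np
    import jax
    from jax.sharding import Mesh

    from tpu_inference.models.jax.llama4 import Llama4ForCausalLM

    # Assumption: vllm_config is a vllm.config.VllmConfig already populated for a
    # Llama-4 checkpoint (model_config, cache_config, load_config, additional_config).
    devices = np.array(jax.devices()).reshape(1, 1, -1)  # axes: (data, expert, model)
    mesh = Mesh(devices, axis_names=("data", "expert", "model"))

    rng = jax.random.PRNGKey(0)
    model = Llama4ForCausalLM(vllm_config, rng, mesh, force_random_weights=True)
    model.load_weights(rng)  # replaces the random init with sharded checkpoint weights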