tpu-inference 0.11.1rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_adapters.py +83 -0
- tests/core/test_core_tpu.py +523 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/test_lora.py +123 -0
- tests/test_base.py +201 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +218 -0
- tests/tpu_backend_test.py +59 -0
- tpu_inference/__init__.py +30 -0
- tpu_inference/adapters/__init__.py +0 -0
- tpu_inference/adapters/vllm_adapters.py +42 -0
- tpu_inference/adapters/vllm_config_adapters.py +134 -0
- tpu_inference/backend.py +69 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/adapters.py +153 -0
- tpu_inference/core/core_tpu.py +776 -0
- tpu_inference/core/disagg_executor.py +117 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/di/__init__.py +0 -0
- tpu_inference/di/abstracts.py +28 -0
- tpu_inference/di/host.py +76 -0
- tpu_inference/di/interfaces.py +51 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/tpu_connector.py +699 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +346 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/interfaces/__init__.py +0 -0
- tpu_inference/interfaces/cache.py +31 -0
- tpu_inference/interfaces/config.py +47 -0
- tpu_inference/interfaces/config_parts.py +117 -0
- tpu_inference/interfaces/engine.py +51 -0
- tpu_inference/interfaces/outputs.py +22 -0
- tpu_inference/interfaces/params.py +21 -0
- tpu_inference/interfaces/platform.py +74 -0
- tpu_inference/interfaces/request.py +39 -0
- tpu_inference/interfaces/scheduler.py +31 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +308 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1233 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/llama3.py +366 -0
- tpu_inference/models/jax/llama4.py +473 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +976 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
- tpu_inference/models/jax/utils/weight_utils.py +510 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_jax.py +257 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table_jax.py +122 -0
- tpu_inference/runner/compilation_manager.py +672 -0
- tpu_inference/runner/input_batch_jax.py +435 -0
- tpu_inference/runner/kv_cache.py +119 -0
- tpu_inference/runner/kv_cache_manager.py +460 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +208 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +250 -0
- tpu_inference/runner/structured_decoding_manager.py +89 -0
- tpu_inference/runner/tpu_jax_runner.py +771 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +334 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +294 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/_temporary_vllm_compat.py +129 -0
- tpu_inference/worker/base.py +100 -0
- tpu_inference/worker/tpu_worker_jax.py +321 -0
- tpu_inference-0.11.1rc1.dist-info/METADATA +101 -0
- tpu_inference-0.11.1rc1.dist-info/RECORD +123 -0
- tpu_inference-0.11.1rc1.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1rc1.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1rc1.dist-info/top_level.txt +2 -0
tpu_inference/models/jax/llama4.py
@@ -0,0 +1,473 @@
import re
from typing import List, Optional, Tuple

import jax
import jax.numpy as jnp
from flax import nnx
from flax.typing import PRNGKey
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from vllm.config import VllmConfig

from tpu_inference.layers.jax.attention.attention import AttentionMetadata
from tpu_inference.layers.jax.attention.llama4_attention import Llama4Attention
from tpu_inference.layers.jax.constants import KVCacheType
from tpu_inference.layers.jax.layers import DenseFFW, Embedder, LMhead, RMSNorm
from tpu_inference.layers.jax.misc import shard_put
from tpu_inference.layers.jax.moe.moe import MoE, Router
from tpu_inference.layers.jax.transformer_block import \
    SharedExpertsTransformerBlock
from tpu_inference.logger import init_logger
from tpu_inference.models.jax.utils.weight_utils import (
    get_param, model_weights_generator, print_param_info, reshape_params,
    transpose_params)

logger = init_logger(__name__)


class Llama4ForCausalLM(nnx.Module):

    def __init__(self,
                 vllm_config: VllmConfig,
                 rng: PRNGKey,
                 mesh: Mesh,
                 force_random_weights: bool = False):
        assert mesh is not None

        self.vllm_config = vllm_config
        model_config = vllm_config.model_config
        text_config = model_config.hf_config.text_config

        self.rng = nnx.Rngs(rng)
        self.mesh = mesh
        self.is_verbose = getattr(self.vllm_config.additional_config,
                                  "is_verbose", False)

        # Currently the runner will always set a mesh, so the custom default sharding (when
        # no sharding is set in vllm config) doesn't take effect.
        # TODO(fhzhang): figure out whether we need to actually enable this.
        # strategy_dict = {"tensor_parallelism": 4, "expert_parallelism": 2}

        # TODO(fhzhang): remove these once we confirm that the values we get from config are good.
        # self.hidden_size: int = 5120
        # vocab_size = 202048
        self.vocab_size = model_config.get_vocab_size()
        self.hidden_size = model_config.get_hidden_size()

        dtype: jnp.dtype = jnp.bfloat16

        self.num_layers: int = getattr(text_config, "num_hidden_layers", 48)

        self.intermediate_size_moe: int = getattr(text_config,
                                                  "intermediate_size", 8192)
        self.intermediate_size_mlp = getattr(text_config,
                                             "intermediate_size_mlp", 16384)

        # num_local_experts: uses 16 experts for Llama-4-Scout-17B-16E-Instruct and uses 128 experts Llama-4-Maverick-17B-128E-Instruct.
        # The default value is set to 16 for compatibility with Llama-4-Scout.
        self.num_local_experts: int = getattr(text_config, "num_local_experts",
                                              16)
        self.hidden_act: str = getattr(text_config, "hidden_act", "silu")
        self.no_rope_layer_interval = getattr(text_config, "no_rope_layers",
                                              [])

        # interleave_moe_layer_step has a layer step of 2 to interleave MoE and dense layers for Llama-4-Maverick-17B-128E-Instruct.
        # The default value is set to 1 for compatibility with Llama-4-Scout.
        self.interleave_moe_layer_step = getattr(text_config,
                                                 "interleave_moe_layer_step",
                                                 1)

        self.num_attention_heads = getattr(text_config, "num_attention_heads",
                                           40)
        self.num_key_value_heads = getattr(text_config, "num_key_value_heads",
                                           8)
        self.head_dim = getattr(text_config, "head_dim", 128)

        self.num_shared_experts = getattr(text_config, "num_experts_per_tok",
                                          1)
        self.rms_norm_eps = getattr(text_config, "rms_norm_eps", 1e-5)

        self.embedder = Embedder(vocab_size=self.vocab_size,
                                 hidden_size=self.hidden_size,
                                 dtype=dtype,
                                 vd_sharding=(('data', 'expert', 'model'),
                                              None),
                                 rngs=self.rng,
                                 random_init=force_random_weights)

        self.layers = []

        for i in range(self.num_layers):
            # For Llama4-Scout, all layers are MoE layers.
            # This can be adjusted for other variants.
            is_moe_layer = (i + 1) % \
                self.interleave_moe_layer_step == 0

            # Llama-4-Scout config: It has "no_rope_layers": []
            use_attention_rope = (i + 1) not in self.no_rope_layer_interval

            router = Router(dtype=dtype,
                            hidden_size=self.hidden_size,
                            num_experts=self.num_local_experts,
                            num_experts_per_tok=1,
                            router_act="sigmoid",
                            rngs=self.rng,
                            activation_ffw_td=('data', None),
                            ed_sharding=(None, 'expert'),
                            random_init=force_random_weights)

            custom_module = MoE(
                dtype=dtype,
                num_local_experts=self.num_local_experts,
                apply_expert_weight_before_computation=True,
                hidden_size=self.hidden_size,
                intermediate_size_moe=self.intermediate_size_moe,
                hidden_act=self.hidden_act,
                router=router,
                rngs=self.rng,
                activation_ffw_td=('data', None),
                activation_ffw_ted=('data', 'expert', None),
                edf_sharding=('expert', None, 'model'),
                efd_sharding=('expert', 'model', None),
                random_init=force_random_weights
            ) if is_moe_layer else DenseFFW(
                dtype=dtype,
                hidden_act=self.hidden_act,
                hidden_size=self.hidden_size,
                intermediate_size=self.intermediate_size_mlp,
                random_init=force_random_weights,
                rngs=self.rng,
                df_sharding=(None, 'model'),
                fd_sharding=('model', None),
                activation_ffw_td=('data', None))

            attn = Llama4Attention(
                hidden_size=self.hidden_size,
                dtype=dtype,
                kv_cache_dtype=vllm_config.cache_config.cache_dtype,
                num_attention_heads=self.num_attention_heads,
                num_key_value_heads=self.num_key_value_heads,
                head_dim=self.head_dim,
                rope_theta=500000.0,
                # https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/config.json
                rope_scaling={
                    "scale_factor": 16.0,
                    "low_freq_factor": 1.0,
                    "high_freq_factor": 1.0,
                    "original_max_position_embeddings": 8192
                },
                rngs=self.rng,
                rope_input_ordering="interleaved",
                temperature_tuning=True,
                temperature_tuning_scale=0.1,
                temperature_tuning_floor_scale=8192,
                use_qk_norm=True,
                attention_chunk_size=None if use_attention_rope else 8192,
                mesh=self.mesh,
                random_init=force_random_weights,
                activation_attention_td=('data', 'model'),
                activation_q_td=('data', 'model'),
                query_tnh=P('data', 'model', None),
                keyvalue_skh=P('data', 'model', None),
                activation_attention_out_td=('data', 'model'),
                attn_o_tnh=P('data', 'model', None),
                dnh_sharding=(None, 'model', None),
                dkh_sharding=(None, 'model', None),
                nhd_sharding=('model', None, None),
            )

            shared_experts = DenseFFW(
                dtype=dtype,
                hidden_act=self.hidden_act,
                hidden_size=self.hidden_size,
                intermediate_size=self.num_shared_experts *
                self.intermediate_size_moe,
                rngs=self.rng,
                random_init=force_random_weights,
                df_sharding=(None, 'model'),
                fd_sharding=('model', None),
                activation_ffw_td=('data', None)) if is_moe_layer else None

            pre_attention_norm = RMSNorm(
                dims=self.hidden_size,
                random_init=force_random_weights,
                epsilon=self.rms_norm_eps,
                rngs=self.rng,
                with_scale=True,
                dtype=dtype,
            )

            pre_mlp_norm = RMSNorm(
                dims=self.hidden_size,
                epsilon=self.rms_norm_eps,
                rngs=self.rng,
                with_scale=True,
                dtype=dtype,
                random_init=force_random_weights,
            )

            block = SharedExpertsTransformerBlock(
                custom_module=custom_module,
                attn=attn,
                pre_attention_norm=pre_attention_norm,
                pre_mlp_norm=pre_mlp_norm,
                shared_experts=shared_experts,
                use_attention_rope=use_attention_rope)
            self.layers.append(block)

        self.final_norm = RMSNorm(
            dims=self.hidden_size,
            epsilon=self.rms_norm_eps,
            rngs=self.rng,
            with_scale=True,
            dtype=dtype,
            random_init=force_random_weights,
        )

        self.lm_head = LMhead(vocab_size=self.vocab_size,
                              hidden_size=self.hidden_size,
                              dtype=dtype,
                              rngs=self.rng,
                              vd_sharding=(('data', 'expert', 'model'), None),
                              dv_sharding=(None, ('data', 'expert', 'model')),
                              random_init=force_random_weights)
        if self.is_verbose:
            self._print_model_architecture()

    def _print_model_architecture(self):
        num_display_layers = max(self.interleave_moe_layer_step,
                                 self.no_rope_layer_interval)

        logger.info("### Embedding ###")
        nnx.display(self.embedder)

        logger.info(f"\n### First {num_display_layers} Layers ###")
        # Loop through the slice and display each layer
        for i, layer in enumerate(self.layers[:num_display_layers]):
            logger.info(f"\n--- Layer {i} ---")
            nnx.display(layer)

        logger.info("\n### LM Head ###")
        nnx.display(self.lm_head)

    def load_weights(self, rng: jax.Array, cache_dir: Optional[str] = None):
        # NOTE: Since we are using nnx.eval_shape to init the model,
        # we have to pass dynamic arrays here for __call__'s usage.
        self.rng = nnx.Rngs(rng)

        weight_loader = Llama4WeightLoader(
            vllm_config=self.vllm_config,
            hidden_size=self.hidden_size,
            attn_heads=self.num_attention_heads,
            num_key_value_heads=self.num_key_value_heads,
            attn_head_dim=self.head_dim)
        weight_loader.load_weights(self)

    def __call__(
        self,
        kv_caches: List[jax.Array],
        input_ids: jax.Array,
        attention_metadata: AttentionMetadata,
        *args,
    ) -> Tuple[List[KVCacheType], jax.Array, List[jax.Array]]:
        is_prefill = False
        x_TD = self.embedder.encode(input_ids)

        for (i, block) in enumerate(self.layers):
            kv_cache = kv_caches[i]
            new_kv_cache, x_TD = block(x_TD, is_prefill, kv_cache,
                                       attention_metadata)
            jax.block_until_ready(x_TD)
            kv_caches[i] = new_kv_cache

        final_activation_TD = self.final_norm(x_TD)

        return kv_caches, final_activation_TD, []

    def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
        logits_TV = jnp.dot(hidden_states,
                            self.lm_head.input_embedding_table_DV.value)
        return logits_TV


class Llama4WeightLoader:

    def __init__(self, vllm_config: VllmConfig, hidden_size, attn_heads,
                 num_key_value_heads, attn_head_dim):
        self.names_and_weights_generator = model_weights_generator(
            model_name_or_path=vllm_config.model_config.model,
            framework="flax",
            filter_regex="language_model",
            download_dir=vllm_config.load_config.download_dir)
        self.is_verbose = getattr(vllm_config.additional_config, "is_verbose",
                                  False)
        self.interleave_moe_layer_step = getattr(
            vllm_config.model_config.hf_config.text_config,
            "interleave_moe_layer_step", 1)

        self.expert_prefix = "shared_expert."
        self._transpose_map = {
            "q_proj": (2, 0, 1),
            "k_proj": (2, 0, 1),
            "v_proj": (2, 0, 1),
            "router": (1, 0),
            f"{self.expert_prefix}down_proj": (1, 0),
            f"{self.expert_prefix}gate_proj": (1, 0),
            f"{self.expert_prefix}up_proj": (1, 0),
            "feed_forward.down_proj": (1, 0),
            "feed_forward.gate_proj": (1, 0),
            "feed_forward.up_proj": (1, 0),
            "o_proj": (1, 2, 0),
            "lm_head": (1, 0),
        }

        self._weight_shape_map = {
            "q_proj": (attn_heads, attn_head_dim, hidden_size),
            "k_proj": (num_key_value_heads, attn_head_dim, hidden_size),
            "v_proj": (num_key_value_heads, attn_head_dim, hidden_size),
            # o_proj is inverted: https://github.com/huggingface/transformers/blob/v4.53.2/src/transformers/models/llama4/modeling_llama4.py#L298
            "o_proj": (hidden_size, attn_heads, attn_head_dim),
        }

        # Set the mappings from loaded parameter keys to standardized names.
        self._loaded_to_standardized_keys = {
            "language_model.model.embed_tokens.weight":
            "embedder.input_embedding_table_VD",
            "language_model.lm_head.weight":
            "lm_head.input_embedding_table_DV",
            "language_model.model.norm.weight":
            "final_norm.scale",
            "language_model.model.layers.*.input_layernorm.weight":
            "layers.*.pre_attention_norm.scale",
            "language_model.model.layers.*.post_attention_layernorm.weight":
            "layers.*.pre_mlp_norm.scale",
            "language_model.model.layers.*.self_attn.q_proj.weight":
            "layers.*.attn.kernel_q_proj_DNH",
            "language_model.model.layers.*.self_attn.k_proj.weight":
            "layers.*.attn.kernel_k_proj_DKH",
            "language_model.model.layers.*.self_attn.v_proj.weight":
            "layers.*.attn.kernel_v_proj_DKH",
            "language_model.model.layers.*.self_attn.o_proj.weight":
            "layers.*.attn.kernel_o_proj_NHD",
            "language_model.model.layers.*.feed_forward.router.weight":
            "layers.*.custom_module.router.kernel_DE",
            "language_model.model.layers.*.feed_forward.experts.down_proj":
            "layers.*.custom_module.kernel_down_proj_EFD",
            "language_model.model.layers.*.feed_forward.experts.gate_up_proj":
            "layers.*.custom_module.kernel_up_proj_EDF",
            "language_model.model.layers.*.feed_forward.shared_expert.down_proj.weight":
            "layers.*.shared_experts.kernel_down_proj_FD",
            "language_model.model.layers.*.feed_forward.shared_expert.gate_proj.weight":
            "layers.*.shared_experts.kernel_gating_DF",
            "language_model.model.layers.*.feed_forward.shared_expert.up_proj.weight":
            "layers.*.shared_experts.kernel_up_proj_DF",
            "language_model.model.layers.*.feed_forward.down_proj.weight":
            "layers.*.custom_module.kernel_down_proj_FD",
            "language_model.model.layers.*.feed_forward.up_proj.weight":
            "layers.*.custom_module.kernel_up_proj_DF",
            "language_model.model.layers.*.feed_forward.gate_proj.weight":
            "layers.*.custom_module.kernel_gating_DF",
        }

    def map_loaded_to_standardized_name(self, loaded_key: str) -> str:
        # Find the corresponding model key using the HF key
        if "layer" in loaded_key:
            layer_num = re.search(r"layers\.(\d+)", loaded_key).group(1)
            layer_key = re.sub(r"layers\.\d+", "layers.*", loaded_key)
            mapped_key = self._loaded_to_standardized_keys.get(
                layer_key, loaded_key)
            mapped_key = re.sub(r"layers\.\*", f"layers.{layer_num}",
                                mapped_key)
        else:
            mapped_key = self._loaded_to_standardized_keys.get(
                loaded_key, loaded_key)
        return mapped_key

    def _map_llama4_gate_up_proj(self, model_for_loading: nnx.Module,
                                 model_params: nnx.State, loaded_name: str,
                                 loaded_weight: jax.Array):
        """HF's gate_up_proj is a fused tensor of gate and up projections. It needs to be split."""
        # gate_proj is first & up_proj is second
        split_weights = jnp.split(loaded_weight, 2, axis=-1)

        for split_type in ["gate", "up"]:
            split_loaded_name = loaded_name.replace("gate_up_proj",
                                                    f"{split_type}_proj")
            if split_type == "gate":
                mapped_name = "layers.*.custom_module.kernel_gating_EDF"
                loaded_weight = split_weights[0]
            else:
                mapped_name = "layers.*.custom_module.kernel_up_proj_EDF"
                loaded_weight = split_weights[1]

            layer_num = re.search(r"layers\.(\d+)", split_loaded_name).group(1)
            mapped_name = re.sub(r"layers\.\*", f"layers.{layer_num}",
                                 mapped_name)
            mapped_model_weight = get_param(model_params, mapped_name)

            if mapped_model_weight.value.shape != loaded_weight.shape:
                raise ValueError(
                    f"Loaded shape for {split_loaded_name}: {loaded_weight.shape} "
                    f"does not match model shape for {mapped_name}: {mapped_model_weight.value.shape}!"
                )
            mapped_model_weight.value = shard_put(loaded_weight,
                                                  mapped_model_weight.sharding,
                                                  mesh=model_for_loading.mesh)
            logger.debug(
                f"{split_loaded_name}: {loaded_weight.shape} --> {mapped_name}: {mapped_model_weight.value.shape}"
            )
            if self.is_verbose:
                print_param_info(mapped_model_weight, mapped_name)

    def _get_layer_num(self, loaded_key: str) -> Optional[int]:
        """
        Extracts the layer number from a HuggingFace weight key string.
        Returns the layer number (int) or None if no layer number is found.
        """
        match = re.search(r"layers\.(\d+)", loaded_key)
        if match:
            return int(match.group(1))
        return None

    def load_weights(self, model_for_loading: nnx.Module):
        model_params = nnx.state(model_for_loading)

        with jax.default_device(jax.devices("cpu")[0]):
            for loaded_name, loaded_weight in self.names_and_weights_generator:
                is_moe_layer = False
                layer_num = self._get_layer_num(loaded_name)

                if layer_num is not None:
                    is_moe_layer = (layer_num + 1) % \
                        self.interleave_moe_layer_step == 0
                    self.expert_prefix = "shared_expert." if is_moe_layer else ""

                if "gate_up_proj" in loaded_name:
                    self._map_llama4_gate_up_proj(model_for_loading,
                                                  model_params, loaded_name,
                                                  loaded_weight)
                    continue
                mapped_name = self.map_loaded_to_standardized_name(loaded_name)
                model_weight = get_param(model_params, mapped_name)

                if not loaded_name.endswith(".bias"):
                    loaded_weight = reshape_params(loaded_name, loaded_weight,
                                                   self._weight_shape_map)
                    loaded_weight = transpose_params(loaded_name,
                                                     loaded_weight,
                                                     self._transpose_map)
                if model_weight.value.shape != loaded_weight.shape:
                    raise ValueError(
                        f"Loaded shape for {loaded_name}: {loaded_weight.shape} "
                        f"does not match model shape for {mapped_name}: {model_weight.value.shape}!"
                    )
                logger.debug(
                    f"Transformed parameter {loaded_name} to {mapped_name}: {loaded_weight.shape} --> {model_weight.value.shape}"
                )
                model_weight.value = shard_put(loaded_weight,
                                               model_weight.sharding,
                                               mesh=model_for_loading.mesh)
                if self.is_verbose:
                    print_param_info(model_weight, loaded_name)

        nnx.update(model_for_loading, model_params)
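The key translation in Llama4WeightLoader.map_loaded_to_standardized_name is the core of the weight-loading path above: an HF checkpoint name is generalized to a "layers.*" template, looked up in the mapping table, and then re-specialized with the original layer index; keys without an entry fall through unchanged. Below is a minimal standalone sketch of that behavior. The two dictionary entries and the regexes are copied from the class above, while the surrounding harness (module-level function, example key) is illustrative only and not part of the package.

import re

# Excerpt of the HF-to-model key table from Llama4WeightLoader above.
_LOADED_TO_STANDARDIZED = {
    "language_model.model.embed_tokens.weight":
    "embedder.input_embedding_table_VD",
    "language_model.model.layers.*.self_attn.q_proj.weight":
    "layers.*.attn.kernel_q_proj_DNH",
}


def map_loaded_to_standardized_name(loaded_key: str) -> str:
    # Generalize the concrete layer index to the "layers.*" template,
    # look it up, then substitute the index back into the mapped name.
    if "layer" in loaded_key:
        layer_num = re.search(r"layers\.(\d+)", loaded_key).group(1)
        layer_key = re.sub(r"layers\.\d+", "layers.*", loaded_key)
        mapped_key = _LOADED_TO_STANDARDIZED.get(layer_key, loaded_key)
        return re.sub(r"layers\.\*", f"layers.{layer_num}", mapped_key)
    return _LOADED_TO_STANDARDIZED.get(loaded_key, loaded_key)


print(map_loaded_to_standardized_name(
    "language_model.model.layers.7.self_attn.q_proj.weight"))
# -> layers.7.attn.kernel_q_proj_DNH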