tpu-inference 0.11.1.dev202511150811__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_dp_scheduler.py +899 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/fused_moe_v1_test.py +105 -0
- tests/kernels/mla_v1_test.py +396 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/conftest.py +32 -0
- tests/lora/test_bgmv.py +43 -0
- tests/lora/test_layers.py +654 -0
- tests/lora/test_lora.py +133 -0
- tests/lora/utils.py +96 -0
- tests/test_base.py +201 -0
- tests/test_envs.py +182 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +236 -0
- tpu_inference/__init__.py +34 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/core/sched/__init__.py +0 -0
- tpu_inference/core/sched/dp_scheduler.py +523 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/jax_parallel_state.py +67 -0
- tpu_inference/distributed/tpu_connector.py +728 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +107 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +362 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
- tpu_inference/kernels/mla/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/kernel.py +1349 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_interface.py +390 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/common/sharding.py +582 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +255 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +280 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +96 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
- tpu_inference/layers/jax/transformer_block.py +107 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +507 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +39 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
- tpu_inference/layers/vllm/sharding.py +230 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +311 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +444 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/gpt_oss.py +492 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
- tpu_inference/models/jax/llama3.py +375 -0
- tpu_inference/models/jax/llama4.py +629 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
- tpu_inference/models/jax/utils/weight_utils.py +529 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_platform.py +269 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +780 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +132 -0
- tpu_inference/runner/kv_cache_manager.py +479 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +217 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +248 -0
- tpu_inference/runner/structured_decoding_manager.py +88 -0
- tpu_inference/runner/tpu_runner.py +1620 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +367 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +317 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/tpu_worker.py +321 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tpu_inference/models/jax/llama4.py
@@ -0,0 +1,629 @@
import re
from typing import List, Optional, Tuple

import jax
import jax.numpy as jnp
from flax import nnx
from flax.typing import PRNGKey
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from vllm.config import VllmConfig

from tpu_inference.layers.jax.attention.attention import AttentionMetadata
from tpu_inference.layers.jax.attention.llama4_attention import Llama4Attention
from tpu_inference.layers.jax.constants import KVCacheType
from tpu_inference.layers.jax.layers import DenseFFW, Embedder, LMhead, RMSNorm
from tpu_inference.layers.jax.misc import shard_put
from tpu_inference.layers.jax.moe.moe import MoE, Router
from tpu_inference.layers.jax.transformer_block import \
    SharedExpertsTransformerBlock
from tpu_inference.logger import init_logger
from tpu_inference.models.jax.utils.weight_utils import (
    convert_torch_to_jax_with_view, get_param, model_weights_generator,
    print_param_info, reshape_params, transpose_params)

logger = init_logger(__name__)


class Llama4ForCausalLM(nnx.Module):

    def __init__(self,
                 vllm_config: VllmConfig,
                 rng: PRNGKey,
                 mesh: Mesh,
                 force_random_weights: bool = False):
        assert mesh is not None

        self.vllm_config = vllm_config
        model_config = vllm_config.model_config
        text_config = model_config.hf_config.text_config

        self.rng = nnx.Rngs(rng)
        self.mesh = mesh
        self.is_verbose = getattr(self.vllm_config.additional_config,
                                  "is_verbose", False)

        # Currently the runner will always set a mesh, so the custom default sharding (when
        # no sharding is set in vllm config) doesn't take effect.
        # TODO(fhzhang): figure out whether we need to actually enable this.
        # strategy_dict = {"tensor_parallelism": 4, "expert_parallelism": 2}

        self.vocab_size = model_config.get_vocab_size()
        self.hidden_size = model_config.get_hidden_size()

        dtype: jnp.dtype = jnp.bfloat16

        self.num_layers: int = getattr(text_config, "num_hidden_layers", 48)

        self.intermediate_size_moe: int = getattr(text_config,
                                                  "intermediate_size", 8192)
        self.intermediate_size_mlp = getattr(text_config,
                                             "intermediate_size_mlp", 16384)

        # num_local_experts: uses 16 experts for Llama-4-Scout-17B-16E-Instruct and 128 experts for Llama-4-Maverick-17B-128E-Instruct.
        # The default value is set to 16 for compatibility with Llama-4-Scout.
        self.num_local_experts: int = getattr(text_config, "num_local_experts",
                                              16)
        self.hidden_act: str = getattr(text_config, "hidden_act", "silu")
        self.no_rope_layer_interval = 4

        # interleave_moe_layer_step has a layer step of 2 to interleave MoE and dense layers for Llama-4-Maverick-17B-128E-Instruct.
        # The default value is set to 1 for compatibility with Llama-4-Scout.
        self.interleave_moe_layer_step = getattr(text_config,
                                                 "interleave_moe_layer_step",
                                                 1)

        self.num_attention_heads = getattr(text_config, "num_attention_heads",
                                           40)
        self.num_key_value_heads = getattr(text_config, "num_key_value_heads",
                                           8)
        self.head_dim = getattr(text_config, "head_dim", 128)

        self.num_shared_experts = getattr(text_config, "num_experts_per_tok",
                                          1)
        self.rms_norm_eps = getattr(text_config, "rms_norm_eps", 1e-5)

        self.rope_scaling = getattr(text_config, "rope_scaling", None)
        if self.rope_scaling:
            self.rope_scaling["scale_factor"] = self.rope_scaling.pop("factor")

        self.use_qk_norm = getattr(text_config, "use_qk_norm", True)

        self.embedder = Embedder(vocab_size=self.vocab_size,
                                 hidden_size=self.hidden_size,
                                 dtype=dtype,
                                 vd_sharding=(('data', 'expert', 'model'),
                                              None),
                                 rngs=self.rng,
                                 random_init=force_random_weights)

        self.layers = []

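        # Illustrative layer layout (derived from the two flags computed below):
        # with interleave_moe_layer_step=1 (Llama-4-Scout) every block is an MoE
        # block, while with interleave_moe_layer_step=2 (Llama-4-Maverick) only
        # blocks i = 1, 3, 5, ... are MoE and the rest use the dense FFW.
        # Independently, every no_rope_layer_interval-th block (i = 3, 7, 11, ...)
        # is built without RoPE.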
        for i in range(self.num_layers):
            # For Llama4-Scout, all layers are MoE layers.
            # This can be adjusted for other variants.
            is_moe_layer = (i + 1) % \
                self.interleave_moe_layer_step == 0

            # Llama-4-Scout config: It has "no_rope_layers": []
            use_attention_rope = (i + 1) % self.no_rope_layer_interval != 0

            router = Router(dtype=dtype,
                            hidden_size=self.hidden_size,
                            num_experts=self.num_local_experts,
                            num_experts_per_tok=1,
                            router_act="sigmoid",
                            rngs=self.rng,
                            activation_ffw_td=('data', None),
                            ed_sharding=(None, None),
                            random_init=force_random_weights)

            moe_ffw = MoE(
                dtype=dtype,
                num_local_experts=self.num_local_experts,
                apply_expert_weight_before_computation=True,
                hidden_size=self.hidden_size,
                intermediate_size_moe=self.intermediate_size_moe,
                hidden_act=self.hidden_act,
                router=router,
                rngs=self.rng,
                activation_ffw_td=('data', None),
                activation_ffw_ted=('data', 'expert', None),
                edf_sharding=('model', None, None),
                efd_sharding=('model', None, None),
                random_init=force_random_weights) if is_moe_layer else None

            dense_ffw = DenseFFW(
                dtype=dtype,
                hidden_act=self.hidden_act,
                hidden_size=self.hidden_size,
                intermediate_size=self.intermediate_size_mlp,
                random_init=force_random_weights,
                rngs=self.rng,
                df_sharding=(None, 'model'),
                fd_sharding=('model', None),
                activation_ffw_td=('data', None)) if not is_moe_layer else None

            attn = Llama4Attention(
                hidden_size=self.hidden_size,
                dtype=dtype,
                kv_cache_dtype=vllm_config.cache_config.cache_dtype,
                num_attention_heads=self.num_attention_heads,
                num_key_value_heads=self.num_key_value_heads,
                head_dim=self.head_dim,
                rope_theta=500000.0,
                # https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/config.json
                rope_scaling=self.rope_scaling,
                rngs=self.rng,
                rope_input_ordering="interleaved",
                temperature_tuning=True,
                temperature_tuning_scale=0.1,
                temperature_tuning_floor_scale=8192,
                use_qk_norm=self.use_qk_norm,
                attention_chunk_size=None if use_attention_rope else 8192,
                mesh=self.mesh,
                random_init=force_random_weights,
                activation_attention_td=('data', 'model'),
                activation_q_td=('data', 'model'),
                query_tnh=P('data', 'model', None),
                keyvalue_skh=P('data', 'model', None),
                activation_attention_out_td=('data', 'model'),
                attn_o_tnh=P('data', 'model', None),
                dnh_sharding=(None, 'model', None),
                dkh_sharding=(None, 'model', None),
                nhd_sharding=('model', None, None),
            )

            shared_experts = DenseFFW(
                dtype=dtype,
                hidden_act=self.hidden_act,
                hidden_size=self.hidden_size,
                intermediate_size=self.num_shared_experts *
                self.intermediate_size_moe,
                rngs=self.rng,
                random_init=force_random_weights,
                df_sharding=(None, 'model'),
                fd_sharding=('model', None),
                activation_ffw_td=('data', None)) if is_moe_layer else None

            pre_attention_norm = RMSNorm(
                dims=self.hidden_size,
                random_init=force_random_weights,
                epsilon=self.rms_norm_eps,
                rngs=self.rng,
                with_scale=True,
                dtype=dtype,
                activation_ffw_td=('data', None),
            )

            pre_mlp_norm = RMSNorm(
                dims=self.hidden_size,
                epsilon=self.rms_norm_eps,
                rngs=self.rng,
                with_scale=True,
                dtype=dtype,
                random_init=force_random_weights,
                activation_ffw_td=('data', None),
            )

            block = SharedExpertsTransformerBlock(
                moe_ffw=moe_ffw if is_moe_layer else None,
                dense_ffw=dense_ffw if not is_moe_layer else None,
                shared_experts=shared_experts if is_moe_layer else None,
                attn=attn,
                pre_attention_norm=pre_attention_norm,
                pre_mlp_norm=pre_mlp_norm,
                use_attention_rope=use_attention_rope)
            self.layers.append(block)

        self.final_norm = RMSNorm(
            dims=self.hidden_size,
            epsilon=self.rms_norm_eps,
            rngs=self.rng,
            with_scale=True,
            dtype=dtype,
            random_init=force_random_weights,
        )

        self.lm_head = LMhead(vocab_size=self.vocab_size,
                              hidden_size=self.hidden_size,
                              dtype=dtype,
                              rngs=self.rng,
                              vd_sharding=(('data', 'expert', 'model'), None),
                              dv_sharding=(None, ('data', 'expert', 'model')),
                              random_init=force_random_weights)
        if self.is_verbose:
            self._print_model_architecture()

    def _print_model_architecture(self):
        num_display_layers = max(self.interleave_moe_layer_step,
                                 self.no_rope_layer_interval)

        logger.info("### Embedding ###")
        nnx.display(self.embedder)

        logger.info(f"\n### First {num_display_layers} Layers ###")
        # Loop through the slice and display each layer
        for i, layer in enumerate(self.layers[:num_display_layers]):
            logger.info(f"\n--- Layer {i} ---")
            nnx.display(layer)

        logger.info("\n### LM Head ###")
        nnx.display(self.lm_head)

    def load_weights(self, rng: jax.Array, cache_dir: Optional[str] = None):
        # NOTE: Since we are using nnx.eval_shape to init the model,
        # we have to pass dynamic arrays here for __call__'s usage.
        self.rng = nnx.Rngs(rng)

        weight_loader = Llama4WeightLoader(
            vllm_config=self.vllm_config,
            hidden_size=self.hidden_size,
            attn_heads=self.num_attention_heads,
            num_key_value_heads=self.num_key_value_heads,
            attn_head_dim=self.head_dim)
        weight_loader.load_weights(self)

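    # Shape-suffix convention used below (inferred from the parameter names in
    # this file): T = tokens, D = hidden size, V = vocab size, N = attention
    # heads, K = KV heads, H = head dim, E = local experts, F = FFW intermediate
    # size. E.g. x_TD is a [tokens, hidden] activation and kernel_q_proj_DNH is
    # a [hidden, heads, head_dim] weight.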
    def __call__(
        self,
        kv_caches: List[jax.Array],
        input_ids: jax.Array,
        attention_metadata: AttentionMetadata,
        *args,
    ) -> Tuple[List[KVCacheType], jax.Array, List[jax.Array]]:
        is_prefill = False
        x_TD = self.embedder.encode(input_ids)

        for (i, block) in enumerate(self.layers):
            kv_cache = kv_caches[i]
            new_kv_cache, x_TD = block(x_TD, is_prefill, kv_cache,
                                       attention_metadata)
            jax.block_until_ready(x_TD)
            kv_caches[i] = new_kv_cache

        final_activation_TD = self.final_norm(x_TD)

        return kv_caches, final_activation_TD, []

    def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
        logits_TV = jnp.dot(hidden_states,
                            self.lm_head.input_embedding_table_DV.value)
        return logits_TV


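# Maps HuggingFace Llama-4 checkpoint weights onto the JAX model above:
# renames "language_model.*" keys to the standardized parameter names,
# reshapes/transposes projections into the DNH/NHD/EDF layouts, splits the
# fused gate_up_proj tensor, and (for quantized checkpoints) stacks the
# per-expert weights and scales into single expert-leading arrays.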
class Llama4WeightLoader:

    def __init__(self, vllm_config: VllmConfig, hidden_size, attn_heads,
                 num_key_value_heads, attn_head_dim):
        self.names_and_weights_generator = model_weights_generator(
            model_name_or_path=vllm_config.model_config.model,
            framework="pt",
            filter_regex="language_model",
            download_dir=vllm_config.load_config.download_dir)
        self.is_verbose = getattr(vllm_config.additional_config, "is_verbose",
                                  False)
        self.interleave_moe_layer_step = getattr(
            vllm_config.model_config.hf_config.text_config,
            "interleave_moe_layer_step", 1)

        self.quantization_config = getattr(vllm_config.model_config.hf_config,
                                           "quantization_config", None)
        self.expert_weights_buffer = {}
        self.expert_prefix = "shared_expert."

        transpose_mappings_to_quantization = {
            "down_proj": (1, 0),
            "gate_proj": (1, 0),
            "up_proj": (1, 0),
        }

        self._transpose_map = {
            "q_proj": (2, 0, 1),
            "k_proj": (2, 0, 1),
            "v_proj": (2, 0, 1),
            "router": (1, 0),
            f"{self.expert_prefix}down_proj": (1, 0),
            f"{self.expert_prefix}gate_proj": (1, 0),
            f"{self.expert_prefix}up_proj": (1, 0),
            "feed_forward.down_proj": (1, 0),
            "feed_forward.gate_proj": (1, 0),
            "feed_forward.up_proj": (1, 0),
            "o_proj": (1, 2, 0),
            "lm_head": (1, 0),
        }

        if self.quantization_config and self.expert_prefix:
            self._transpose_map.update(transpose_mappings_to_quantization)

        self._weight_shape_map = {
            "q_proj": (attn_heads, attn_head_dim, hidden_size),
            "k_proj": (num_key_value_heads, attn_head_dim, hidden_size),
            "v_proj": (num_key_value_heads, attn_head_dim, hidden_size),
            # o_proj is inverted: https://github.com/huggingface/transformers/blob/v4.53.2/src/transformers/models/llama4/modeling_llama4.py#L298
            "o_proj": (hidden_size, attn_heads, attn_head_dim),
        }

        # Set the mappings from loaded parameter keys to standardized names.
        # 1. EXPERT_MAPPINGS_FUSED: Used for non-quantized (e.g., BF16) checkpoints.
        # - This format typically comes from standard checkpoints where 'gate' and 'up' projection weights might be combined (FUSED) into a single tensor.
        # - Expert weights are usually stacked, with the expert dimension (E) being the first dimension.
        EXPERT_MAPPINGS_FUSED = {
            "language_model.model.layers.*.feed_forward.experts.down_proj":
            "layers.*.moe_ffw.kernel_down_proj_EFD",
            "language_model.model.layers.*.feed_forward.experts.gate_up_proj":
            "layers.*.moe_ffw.kernel_up_proj_EDF",
        }

        # 2. EXPERT_MAPPINGS_UNFUSED: Specifically designed for quantized checkpoints (e.g., FP8).
        # - Quantized checkpoints store each expert's weights separately and explicitly separate the 'weight' (quantized value) from the 'weight_scale' (quantization scale).
        # - The mapping captures both the `.weight` and `.weight_scale` components. This allows the loader to aggregate (stack) the individual expert weights and scales.
        EXPERT_MAPPINGS_UNFUSED = {
            "language_model.model.layers.*.feed_forward.experts.*.down_proj.weight":
            "layers.*.moe_ffw.kernel_down_proj_EFD",
            "language_model.model.layers.*.feed_forward.experts.*.down_proj.weight_scale":
            "layers.*.moe_ffw.kernel_down_proj_EFD",
            "language_model.model.layers.*.feed_forward.experts.*.gate_proj.weight":
            "layers.*.moe_ffw.kernel_gating_EDF",
            "language_model.model.layers.*.feed_forward.experts.*.gate_proj.weight_scale":
            "layers.*.moe_ffw.kernel_gating_EDF",
            "language_model.model.layers.*.feed_forward.experts.*.up_proj.weight":
            "layers.*.moe_ffw.kernel_up_proj_EDF",
            "language_model.model.layers.*.feed_forward.experts.*.up_proj.weight_scale":
            "layers.*.moe_ffw.kernel_up_proj_EDF",
        }

        self._loaded_to_standardized_keys = {
            "language_model.model.embed_tokens.weight":
            "embedder.input_embedding_table_VD",
            "language_model.lm_head.weight":
            "lm_head.input_embedding_table_DV",
            "language_model.model.norm.weight":
            "final_norm.scale",
            "language_model.model.layers.*.input_layernorm.weight":
            "layers.*.pre_attention_norm.scale",
            "language_model.model.layers.*.post_attention_layernorm.weight":
            "layers.*.pre_mlp_norm.scale",
            "language_model.model.layers.*.self_attn.q_proj.weight":
            "layers.*.attn.kernel_q_proj_DNH",
            "language_model.model.layers.*.self_attn.k_proj.weight":
            "layers.*.attn.kernel_k_proj_DKH",
            "language_model.model.layers.*.self_attn.v_proj.weight":
            "layers.*.attn.kernel_v_proj_DKH",
            "language_model.model.layers.*.self_attn.o_proj.weight":
            "layers.*.attn.kernel_o_proj_NHD",
            "language_model.model.layers.*.feed_forward.router.weight":
            "layers.*.moe_ffw.router.kernel_DE",
            # shared experts
            "language_model.model.layers.*.feed_forward.shared_expert.down_proj.weight":
            "layers.*.shared_experts.kernel_down_proj_FD",
            "language_model.model.layers.*.feed_forward.shared_expert.gate_proj.weight":
            "layers.*.shared_experts.kernel_gating_DF",
            "language_model.model.layers.*.feed_forward.shared_expert.up_proj.weight":
            "layers.*.shared_experts.kernel_up_proj_DF",
            # dense layers
            "language_model.model.layers.*.feed_forward.down_proj.weight":
            "layers.*.dense_ffw.kernel_down_proj_FD",
            "language_model.model.layers.*.feed_forward.up_proj.weight":
            "layers.*.dense_ffw.kernel_up_proj_DF",
            "language_model.model.layers.*.feed_forward.gate_proj.weight":
            "layers.*.dense_ffw.kernel_gating_DF",
        }

        if self.quantization_config is None:
            self._loaded_to_standardized_keys.update(EXPERT_MAPPINGS_FUSED)
        else:
            self._loaded_to_standardized_keys.update(EXPERT_MAPPINGS_UNFUSED)

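    # Example lookup: "language_model.model.layers.3.self_attn.q_proj.weight"
    # is generalized to "language_model.model.layers.*.self_attn.q_proj.weight",
    # mapped to "layers.*.attn.kernel_q_proj_DNH", and the layer number is then
    # substituted back to give "layers.3.attn.kernel_q_proj_DNH".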
    def map_loaded_to_standardized_name(self, loaded_key: str) -> str:
        # Find the corresponding model key using the HF key
        if "layer" in loaded_key:
            layer_num = self._get_layer_num(loaded_key)
            layer_key = re.sub(r"layers\.\d+", "layers.*", loaded_key)

            expert_match = re.search(r"experts\.(\d+)", layer_key)
            if expert_match:
                # Key for lookup eg: layers.*.feed_forward.experts.*.down_proj.weight
                layer_key = re.sub(r"experts\.\d+", "experts.*", layer_key)

            mapped_key = self._loaded_to_standardized_keys.get(
                layer_key, loaded_key)
            mapped_key = re.sub(r"layers\.\*", f"layers.{layer_num}",
                                mapped_key)
        else:
            mapped_key = self._loaded_to_standardized_keys.get(
                loaded_key, loaded_key)
        return mapped_key

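    # The fused experts.gate_up_proj tensor concatenates the gate and up
    # projections along the last axis; the helper below splits it in half, so a
    # fused tensor of, e.g., shape [E, D, 2F] fills kernel_gating_EDF and
    # kernel_up_proj_EDF, each of shape [E, D, F].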
    def _map_llama4_gate_up_proj(self, model_for_loading: nnx.Module,
                                 model_params: nnx.State, loaded_name: str,
                                 loaded_weight: jax.Array):
        """HF's gate_up_proj is a fused tensor of gate and up projections. It needs to be split."""

        cast_type = jnp.dtype(jnp.bfloat16)
        # loaded_weight is a jax.Array when framework="flax", otherwise it's bfloat16
        if not isinstance(loaded_weight, jax.Array):
            loaded_weight = convert_torch_to_jax_with_view(
                loaded_weight, cast_type)

        split_weights = jnp.split(loaded_weight, 2, axis=-1)
        layer_num = self._get_layer_num(loaded_name)

        for split_type in ["gate", "up"]:
            split_loaded_name = loaded_name.replace("gate_up_proj",
                                                    f"{split_type}_proj")
            if split_type == "gate":
                mapped_name = "layers.*.moe_ffw.kernel_gating_EDF"
                loaded_weight = split_weights[0]
            else:
                mapped_name = "layers.*.moe_ffw.kernel_up_proj_EDF"
                loaded_weight = split_weights[1]

            mapped_name = re.sub(r"layers\.\*", f"layers.{layer_num}",
                                 mapped_name)

            mapped_model_weight = get_param(model_params, mapped_name)

            if mapped_model_weight.value.shape != loaded_weight.shape:
                raise ValueError(
                    f"Loaded shape for {split_loaded_name}: {loaded_weight.shape} "
                    f"does not match model shape for {mapped_name}: {mapped_model_weight.value.shape}!"
                )

            mapped_model_weight.value = shard_put(loaded_weight,
                                                  mapped_model_weight.sharding,
                                                  mesh=model_for_loading.mesh)
            logger.debug(
                f"{split_loaded_name}: {loaded_weight.shape} --> {mapped_name}: {mapped_model_weight.value.shape}"
            )
            if self.is_verbose:
                print_param_info(mapped_model_weight, mapped_name)

    def _get_layer_num(self, loaded_key: str) -> Optional[int]:
        """
        Extracts the layer number from a HuggingFace weight key string.
        Returns the layer number (int) or None if no layer number is found.
        """
        match = re.search(r"layers\.(\d+)", loaded_key)
        if match:
            return int(match.group(1))
        return None

    def _get_expert_num(self, loaded_key: str) -> Optional[int]:
        """
        Extracts the expert number from a HuggingFace weight key string.
        Returns the expert number (int) or None if no expert number is found.
        """
        match = re.search(r"experts\.(\d+)\.", loaded_key)
        if match:
            return int(match.group(1))
        return None

    def load_weights(self, model_for_loading: nnx.Module):
        model_params = nnx.state(model_for_loading)

        with jax.default_device(jax.devices("cpu")[0]):
            for loaded_name, loaded_weight in self.names_and_weights_generator:
                is_moe_layer = False
                layer_num = self._get_layer_num(loaded_name)
                expert_num = self._get_expert_num(loaded_name)
                # Quantized (FP8) checkpoints unstack the expert weights, while unquantized (BF16) checkpoints keep them stacked.
                is_unfused_expert = self.quantization_config is not None and expert_num is not None
                is_scale = loaded_name.endswith(".weight_scale")

                if is_unfused_expert:
                    mapped_name = self.map_loaded_to_standardized_name(
                        loaded_name)
                    model_weight = get_param(model_params, mapped_name)

                    if is_scale:
                        cast_type = model_weight.array.scale.value.dtype
                    else:
                        cast_type = model_weight.array.qvalue.value.dtype

                    loaded_weight = convert_torch_to_jax_with_view(
                        loaded_weight, cast_type)
                    loaded_weight = transpose_params(loaded_name,
                                                     loaded_weight,
                                                     self._transpose_map)

                    buffer_key = f"{mapped_name}_{'scale' if is_scale else 'qvalue'}"
                    if buffer_key not in self.expert_weights_buffer:
                        self.expert_weights_buffer[buffer_key] = {}
                    self.expert_weights_buffer[buffer_key][
                        expert_num] = loaded_weight
                    continue

                if layer_num is not None:
                    is_moe_layer = (layer_num + 1) % \
                        self.interleave_moe_layer_step == 0
                    self.expert_prefix = "shared_expert." if is_moe_layer else ""

                if "gate_up_proj" in loaded_name:
                    self._map_llama4_gate_up_proj(model_for_loading,
                                                  model_params, loaded_name,
                                                  loaded_weight)
                    continue

                mapped_name = self.map_loaded_to_standardized_name(loaded_name)
                model_weight = get_param(model_params, mapped_name)

                cast_type = model_weight.value.dtype
                if not isinstance(loaded_weight, jax.Array):
                    logger.debug(
                        f"Converting PyTorch tensor {loaded_name} to JAX {cast_type}"
                    )
                    loaded_weight = convert_torch_to_jax_with_view(
                        loaded_weight, cast_type)

                if not loaded_name.endswith(".bias"):
                    loaded_weight = reshape_params(loaded_name, loaded_weight,
                                                   self._weight_shape_map)
                    loaded_weight = transpose_params(loaded_name,
                                                     loaded_weight,
                                                     self._transpose_map)
                if model_weight.value.shape != loaded_weight.shape:
                    raise ValueError(
                        f"Loaded shape for {loaded_name}: {loaded_weight.shape} "
                        f"does not match model shape for {mapped_name}: {model_weight.value.shape}!"
                    )
                logger.debug(
                    f"Transformed parameter {loaded_name} to {mapped_name}: {loaded_weight.shape} --> {model_weight.value.shape}"
                )

                model_weight.value = shard_put(loaded_weight,
                                               model_weight.sharding,
                                               mesh=model_for_loading.mesh)
                if self.is_verbose:
                    print_param_info(model_weight, loaded_name)

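        # Second pass: the per-expert quantized weights/scales buffered above
        # are stacked (sorted by expert index) along a new leading expert axis
        # so they match the stacked, expert-leading layout of the MoE params.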
        with jax.default_device(jax.devices("cpu")[0]):
            for buffer_key, expert_map in self.expert_weights_buffer.items(
            ):
                sorted_exp_nums = sorted(expert_map.keys())
                aggregated_weight = jnp.stack(
                    [expert_map[k] for k in sorted_exp_nums], axis=0)
                is_scale = buffer_key.endswith("_scale")
                base_mapped_name = buffer_key.replace("_scale",
                                                      "").replace(
                                                          "_qvalue", "")

                model_weight = get_param(model_params, base_mapped_name)

                assert hasattr(
                    model_weight, 'array'
                ), f"Expected MoE weight '{base_mapped_name}' to be a quantized array (qarray)"

                if is_scale:
                    loaded_name = f"{base_mapped_name}.array.scale.value"
                    if model_weight.array.scale.value.shape != aggregated_weight.shape:
                        raise ValueError(
                            f"[AGGREGATED] Loaded shape for {buffer_key}: {aggregated_weight.shape} "
                            f"does not match model shape for {loaded_name}: {model_weight.array.scale.value.shape}!"
                        )

                    model_weight.array.scale.value = shard_put(
                        aggregated_weight,
                        model_weight.array.scale.sharding,
                        mesh=model_for_loading.mesh)

                elif aggregated_weight.itemsize < 2:  # check model weight elem nbits < 16
                    loaded_name = f"{base_mapped_name}.array.qvalue.value"
                    if model_weight.array.qvalue.value.shape != aggregated_weight.shape:
                        raise ValueError(
                            f"[AGGREGATED] Loaded shape for {buffer_key}: {aggregated_weight.shape} "
                            f"does not match model shape for {loaded_name}: {model_weight.array.qvalue.value.shape}!"
                        )

                    model_weight.array.qvalue.value = shard_put(
                        aggregated_weight,
                        model_weight.array.qvalue.sharding,
                        mesh=model_for_loading.mesh)

                logger.debug(
                    f"Aggregated and loaded {loaded_name}: {aggregated_weight.shape}"
                )

                if self.is_verbose:
                    print_param_info(model_weight, loaded_name)

        nnx.update(model_for_loading, model_params)