tpu_inference-0.12.0.dev20251222-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +67 -0
- tests/core/test_dp_scheduler.py +724 -0
- tests/core/test_init.py +63 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +393 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +291 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +388 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +498 -0
- tests/kernels/quantized_matmul_kernel_test.py +159 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/layers/jax/test_qwix.py +969 -0
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +297 -0
- tests/layers/vllm/test_unquantized.py +621 -0
- tests/layers/vllm/utils.py +72 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +46 -0
- tests/lora/test_bgmv.py +57 -0
- tests/lora/test_layers.py +666 -0
- tests/lora/test_lora.py +147 -0
- tests/lora/test_lora_perf.py +67 -0
- tests/lora/utils.py +88 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +606 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +202 -0
- tests/runner/test_tpu_runner_dp.py +1033 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +215 -0
- tests/test_envs.py +280 -0
- tests/test_tpu_info.py +134 -0
- tests/test_utils.py +193 -0
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +67 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +49 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +814 -0
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +81 -0
- tpu_inference/distributed/tpu_connector.py +732 -0
- tpu_inference/distributed/utils.py +112 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +191 -0
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +399 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +272 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +1340 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +403 -0
- tpu_inference/layers/common/attention_metadata.py +48 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +23 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +600 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +268 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
- tpu_inference/layers/jax/base.py +165 -0
- tpu_inference/layers/jax/constants.py +101 -0
- tpu_inference/layers/jax/layers.py +315 -0
- tpu_inference/layers/jax/misc.py +30 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
- tpu_inference/layers/jax/moe/moe.py +249 -0
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +294 -0
- tpu_inference/layers/jax/rope_interface.py +228 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
- tpu_inference/layers/jax/sample/sampling.py +110 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
- tpu_inference/layers/jax/transformer_block.py +121 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +502 -0
- tpu_inference/layers/vllm/linear_common.py +221 -0
- tpu_inference/layers/vllm/quantization/__init__.py +55 -0
- tpu_inference/layers/vllm/quantization/awq.py +221 -0
- tpu_inference/layers/vllm/quantization/common.py +124 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
- tpu_inference/layers/vllm/sharding.py +244 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +98 -0
- tpu_inference/lora/torch_punica_tpu.py +310 -0
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +520 -0
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +978 -0
- tpu_inference/models/jax/gpt_oss.py +508 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
- tpu_inference/models/jax/llama3.py +436 -0
- tpu_inference/models/jax/llama4.py +643 -0
- tpu_inference/models/jax/llama_eagle3.py +350 -0
- tpu_inference/models/jax/llama_guard_4.py +375 -0
- tpu_inference/models/jax/qwen2.py +390 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
- tpu_inference/models/jax/qwen3.py +318 -0
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +110 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
- tpu_inference/models/jax/utils/weight_utils.py +621 -0
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
- tpu_inference/platforms/__init__.py +16 -0
- tpu_inference/platforms/tpu_platform.py +258 -0
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +890 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +166 -0
- tpu_inference/runner/kv_cache_manager.py +508 -0
- tpu_inference/runner/lora_utils.py +106 -0
- tpu_inference/runner/multimodal_manager.py +231 -0
- tpu_inference/runner/persistent_batch_manager.py +296 -0
- tpu_inference/runner/speculative_decoding_manager.py +262 -0
- tpu_inference/runner/structured_decoding_manager.py +101 -0
- tpu_inference/runner/tpu_runner.py +1768 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +430 -0
- tpu_inference/tpu_info.py +92 -0
- tpu_inference/utils.py +345 -0
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +468 -0
- tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
- tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
- tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
- tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tpu_inference/models/jax/deepseek_v3.py
@@ -0,0 +1,978 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
from dataclasses import dataclass
from typing import List, Optional, Tuple

import jax
import jax.numpy as jnp
import torch
from flax import nnx
from flax.typing import PRNGKey
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from torchax.ops.mappings import j2t_dtype
from vllm.config import VllmConfig

from tpu_inference import utils
from tpu_inference.layers.common.quantization import u8_unpack_e2m1
from tpu_inference.layers.common.sharding import ShardingAxisName
from tpu_inference.layers.jax.attention.attention import AttentionMetadata
from tpu_inference.layers.jax.attention.deepseek_v3_attention import MLA
from tpu_inference.layers.jax.constants import KVCacheType
from tpu_inference.layers.jax.layers import DenseFFW, Embedder, LMhead, RMSNorm
from tpu_inference.layers.jax.moe.deepseek_v3_moe import (DeepSeekV3Router,
                                                          SparseMoE)
from tpu_inference.layers.jax.moe.moe import MoE
from tpu_inference.layers.jax.transformer_block import (
    SharedExpertsTransformerBlock, TransformerBlock)
from tpu_inference.logger import init_logger
from tpu_inference.models.jax.utils.weight_utils import (
    get_param, model_weights_generator, print_param_info)

logger = init_logger(__name__)

# A map from JAX dtype to the corresponding PyTorch integer dtype for raw memory viewing.
DTYPE_VIEW_MAP = {
    jnp.dtype(jnp.float8_e4m3fn): torch.uint8,
    jnp.dtype(jnp.bfloat16): torch.uint16,
    jnp.dtype(jnp.float32): torch.uint32,
}


@dataclass
class DeepSeekV3(nnx.Module):

    def __init__(self,
                 vllm_config: VllmConfig,
                 rng: jax.Array,
                 mesh: Mesh,
                 force_random_weights: bool = False):
        assert mesh is not None

        self.vllm_config = vllm_config
        self.rng = nnx.Rngs(rng)

        # NOTE: the default is 61
        num_layers: int = vllm_config.model_config.hf_config.num_hidden_layers
        num_local_experts: int = 256

        vocab_size: int = 129280
        hidden_size: int = 7168
        # NOTE: this dtype may be implicitly overridden if using Qwix to load in the quantized weights
        dtype: jnp.dtype = jnp.bfloat16
        num_attention_heads: int = 128
        num_key_value_heads: int = 128
        ffw_intermediate_size: int = 18432
        moe_intermediate_size: int = 2048
        num_experts_per_token: int = 8
        n_group: int = 8
        interleave_moe_layer_step: int = 1  # Deepseek V3 has moe_layer_freq=1 in hf config.
        hidden_act: str = "silu"
        rms_norm_eps: float = 1e-06
        first_k_dense_replace: int = 3  # replace the first few MoE layers with dense layers.
        self.use_mla_kernel: bool = self.vllm_config.model_config.use_mla

        logger.info(f"Is using MLA kernel in DeepSeek: {self.use_mla_kernel}")

        num_shared_experts = 1
        rope_theta = 10000
        rope_scaling = {
            "beta_fast": 32,
            "beta_slow": 1,
            "factor": 40,
            "mscale": 1.0,
            "mscale_all_dim": 1.0,
            "original_max_position_embeddings": 4096,
            "type": "yarn"
        }
        q_lora_rank = 1536
        kv_lora_rank = 512
        qk_nope_head_dim = 128
        qk_rope_head_dim = 64
        v_head_dim = 128

        self.random_init = force_random_weights or self.vllm_config.additional_config.get(
            "random_weights", False)
        self.sparse_matmul = self.vllm_config.additional_config.get(
            "sparse_matmul", False)

        if isinstance(self.sparse_matmul, str):
            self.sparse_matmul = self.sparse_matmul.lower() == "true"
        else:
            self.sparse_matmul = bool(self.sparse_matmul)

        if self.sparse_matmul:
            logger.info("sparse matmul is enabled")
        else:
            logger.info("sparse matmul is disabled, using dense matmul")
        self.mesh = mesh

        self.weight_loader = DeepSeekV3WeightLoader(
            vllm_config=vllm_config,
            num_layers=num_layers,
            hidden_size=hidden_size,
            q_lora_rank=q_lora_rank,
            kv_lora_rank=kv_lora_rank,
            attn_heads=num_attention_heads,
            qk_nope_head_dim=qk_nope_head_dim,
            qk_rope_head_dim=qk_rope_head_dim,
            v_head_dim=v_head_dim,
            num_local_experts=num_local_experts,
            model_dtype=dtype,
            use_mla_kernel=self.use_mla_kernel)

        self.embedder = Embedder(vocab_size=vocab_size,
                                 hidden_size=hidden_size,
                                 dtype=dtype,
                                 rngs=self.rng,
                                 vd_sharding=(ShardingAxisName.MLP_TENSOR,
                                              None),
                                 random_init=self.random_init)

        self.layers = []

        def _create_mla() -> MLA:
            if self.use_mla_kernel:
                query_tnh_spec = P(ShardingAxisName.MLP_TENSOR, None, None)
                keyvalue_skh_spec = P(ShardingAxisName.MLP_TENSOR, None)
                attn_o_tnh_spec = P(ShardingAxisName.MLP_TENSOR, None, None)

            else:
                query_tnh_spec = P(None, ShardingAxisName.MLP_TENSOR, None)
                keyvalue_skh_spec = P(None, ShardingAxisName.MLP_TENSOR, None)
                attn_o_tnh_spec = P(None, ShardingAxisName.MLP_TENSOR, None)

            return MLA(
                rope_theta=rope_theta,
                rope_scaling=rope_scaling,
                q_lora_rank=q_lora_rank,
                kv_lora_rank=kv_lora_rank,
                qk_nope_head_dim=qk_nope_head_dim,
                qk_rope_head_dim=qk_rope_head_dim,
                rms_norm_eps=rms_norm_eps,
                v_head_dim=v_head_dim,
                mesh=self.mesh,
                use_mla_kernel=self.use_mla_kernel,
                random_init=self.random_init,
                hidden_size=hidden_size,
                num_attention_heads=num_attention_heads,
                num_key_value_heads=1
                if self.use_mla_kernel else num_key_value_heads,
                head_dim=v_head_dim,  # MLA uses v_head_dim as head_dim
                dtype=dtype,
                # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
                kv_cache_dtype=vllm_config.cache_config.cache_dtype,
                rngs=self.rng,
                activation_attention_td=(None, None),
                activation_q_td=(None, None),
                query_tnh=query_tnh_spec,
                keyvalue_skh=keyvalue_skh_spec,
                activation_attention_out_td=(None, None),
                attn_o_tnh=attn_o_tnh_spec,
                q_da_sharding=(None, ShardingAxisName.VOCAB),
                ap_sharding=(None, ShardingAxisName.MLP_TENSOR),
                anh_sharding=(None, ShardingAxisName.MLP_TENSOR, None),
                kv_da_sharding=(None, ShardingAxisName.VOCAB),
                rd_sharding=(ShardingAxisName.MLP_TENSOR, None))

        for i in range(first_k_dense_replace):
            block = TransformerBlock(
                pre_attention_norm=RMSNorm(
                    dims=hidden_size,
                    random_init=self.random_init,
                    epsilon=rms_norm_eps,
                    with_scale=True,
                    dtype=dtype,
                    rngs=self.rng,
                ),
                pre_mlp_norm=RMSNorm(
                    dims=hidden_size,
                    random_init=self.random_init,
                    epsilon=rms_norm_eps,
                    with_scale=True,
                    dtype=dtype,
                    rngs=self.rng,
                ),
                attn=_create_mla(),
                custom_module=DenseFFW(
                    dtype=dtype,
                    hidden_act=hidden_act,
                    hidden_size=hidden_size,
                    intermediate_size=ffw_intermediate_size,
                    rngs=self.rng,
                    df_sharding=(None, ShardingAxisName.MLP_TENSOR),
                    fd_sharding=(ShardingAxisName.MLP_TENSOR, None),
                    random_init=self.random_init))

            self.layers.append(block)

        for i in range(first_k_dense_replace, num_layers):
            is_moe_layer = ((i + 1) % interleave_moe_layer_step == 0)
            router = DeepSeekV3Router(
                random_init=self.random_init,
                hidden_size=hidden_size,
                num_experts=num_local_experts,
                num_experts_per_tok=num_experts_per_token,
                n_groups=n_group,
                topk_groups=4,
                norm_topk_prob=True,
                rngs=self.rng,
                routed_scaling_factor=2.5,
                dtype=dtype,
                activation_ffw_td=(ShardingAxisName.MLP_DATA, None),
                ed_sharding=(ShardingAxisName.MLP_TENSOR, None),
                e_sharding=(ShardingAxisName.MLP_TENSOR, ))
            if self.sparse_matmul:
                # TODO: organize the SparseMoE and DenseMoE better given they share most interfaces
                custom_module = SparseMoE(
                    dtype=dtype,
                    num_local_experts=num_local_experts,
                    apply_expert_weight_before_computation=False,
                    hidden_size=hidden_size,
                    intermediate_size_moe=moe_intermediate_size,
                    num_experts_per_tok=num_experts_per_token,
                    mesh=self.mesh,
                    hidden_act=hidden_act,
                    rngs=self.rng,
                    random_init=self.random_init,
                    activation_ffw_td=(ShardingAxisName.MLP_TENSOR, None),
                    activation_ffw_ted=(ShardingAxisName.MLP_DATA, None, None),
                    edf_sharding=(ShardingAxisName.MLP_TENSOR, None, None),
                    efd_sharding=(ShardingAxisName.MLP_TENSOR, None, None),
                    quantized_dtype=self.weight_loader.quant_dtype
                    if self.weight_loader.is_model_quantized else None,
                    router=router) if is_moe_layer else DenseFFW(
                        dtype=dtype,
                        hidden_act=hidden_act,
                        hidden_size=hidden_size,
                        intermediate_size=ffw_intermediate_size,
                        rngs=self.rng,
                        random_init=self.random_init,
                        df_sharding=(None, ShardingAxisName.MLP_TENSOR),
                        fd_sharding=(ShardingAxisName.MLP_TENSOR, None))
            else:
                custom_module = MoE(
                    dtype=dtype,
                    num_local_experts=num_local_experts,
                    apply_expert_weight_before_computation=False,
                    hidden_size=hidden_size,
                    intermediate_size_moe=moe_intermediate_size,
                    hidden_act=hidden_act,
                    rngs=self.rng,
                    random_init=self.random_init,
                    activation_ffw_td=(ShardingAxisName.MLP_DATA, None),
                    activation_ffw_ted=(ShardingAxisName.MLP_DATA, None, None),
                    edf_sharding=(ShardingAxisName.MLP_TENSOR, None, None),
                    efd_sharding=(ShardingAxisName.MLP_TENSOR, None, None),
                    router=router) if is_moe_layer else DenseFFW(
                        dtype=dtype,
                        hidden_act=hidden_act,
                        hidden_size=hidden_size,
                        intermediate_size=ffw_intermediate_size,
                        rngs=self.rng,
                        random_init=self.random_init,
                        df_sharding=(None, ShardingAxisName.MLP_TENSOR),
                        fd_sharding=(ShardingAxisName.MLP_TENSOR, None))

            shared_experts = DenseFFW(
                dtype=dtype,
                hidden_act=hidden_act,
                hidden_size=hidden_size,
                intermediate_size=num_shared_experts * moe_intermediate_size,
                rngs=self.rng,
                random_init=self.random_init,
                df_sharding=(None, ShardingAxisName.MLP_TENSOR),
                fd_sharding=(ShardingAxisName.MLP_TENSOR, None))

            pre_attention_norm = RMSNorm(
                dims=hidden_size,
                rngs=self.rng,
                random_init=self.random_init,
                epsilon=rms_norm_eps,
                with_scale=True,
                dtype=dtype,
            )

            pre_mlp_norm = RMSNorm(
                dims=hidden_size,
                rngs=self.rng,
                random_init=self.random_init,
                epsilon=rms_norm_eps,
                with_scale=True,
                dtype=dtype,
            )

            block = SharedExpertsTransformerBlock(
                custom_module=custom_module,
                attn=_create_mla(),
                pre_attention_norm=pre_attention_norm,
                pre_mlp_norm=pre_mlp_norm,
                shared_experts=shared_experts)
            self.layers.append(block)

        self.final_norm = RMSNorm(
            dims=hidden_size,
            rngs=self.rng,
            random_init=self.random_init,
            epsilon=rms_norm_eps,
            with_scale=True,
            dtype=dtype,
        )

        self.lm_head = LMhead(vocab_size=vocab_size,
                              hidden_size=hidden_size,
                              dtype=dtype,
                              rngs=self.rng,
                              vd_sharding=(ShardingAxisName.MLP_TENSOR, None),
                              dv_sharding=(None, ShardingAxisName.MLP_TENSOR),
                              random_init=self.random_init)

        if os.environ.get("VLLM_LOGGING_LEVEL", "").upper() == "DEBUG":
            self._print_model_architecture()

    def _print_model_architecture(self):
        num_display_layers = 5

        logger.debug("### Embedding ###")
        nnx.display(self.embedder)

        logger.debug(f"\n### First {num_display_layers} Layers ###")
        # Loop through the slice and display each layer
        for i, layer in enumerate(self.layers[:num_display_layers]):
            logger.debug(f"\n--- Layer {i} ---")
            nnx.display(layer)

        logger.debug("\n### LM Head ###")
        nnx.display(self.lm_head)

    # For compatibility with flax.
    def apply(self, variables, *args, **kwargs):
        return self.__call__(*args, **kwargs)

    def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
        # NOTE: Since we are using nnx.eval_shape to init the model,
        # we have to pass dynamic arrays here for __call__'s usage.
        self.rng = nnx.Rngs(rng)
        self.weight_loader.load_weights(self)
        self.initialize_cache()

    def initialize_cache(self):
        # Initialize RoPE caches after weights are loaded and before JIT compilation.
        for layer in self.layers:
            if hasattr(layer, 'attn') and hasattr(layer.attn, 'rope'):
                if hasattr(layer.attn.rope, 'initialize_cache'):
                    layer.attn.rope.initialize_cache(self.mesh)

    def __call__(
        self,
        kv_caches: List[jax.Array],
        input_ids: jax.Array,
        attention_metadata: AttentionMetadata,
        *args,
    ) -> Tuple[List[KVCacheType], jax.Array, List[jax.Array]]:
        is_prefill = False
        x = self.embedder.encode(input_ids)
        for (i, block) in enumerate(self.layers):
            kv_cache = kv_caches[i]
            new_kv_cache, x = block(x, is_prefill, kv_cache,
                                    attention_metadata)
            kv_caches[i] = new_kv_cache

        final_activation = self.final_norm(x)

        return kv_caches, final_activation, []

    def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
        return self.lm_head.decode(hidden_states)


@dataclass
class DeepSeekV3WeightLoader:

    def __init__(self,
                 vllm_config: VllmConfig,
                 num_layers,
                 hidden_size,
                 q_lora_rank,
                 kv_lora_rank,
                 attn_heads,
                 qk_nope_head_dim,
                 qk_rope_head_dim,
                 v_head_dim,
                 num_local_experts,
                 model_dtype,
                 use_mla_kernel=False):
        self.num_layers = num_layers
        self.names_and_weights_generator = model_weights_generator(
            model_name_or_path=vllm_config.model_config.model,
            framework="pt",
            download_dir=vllm_config.load_config.download_dir)
        self.is_verbose = vllm_config.additional_config.get(
            "is_verbose", None) is not None
        self.num_routed_experts = num_local_experts
        self.attn_heads = attn_heads
        self.qk_nope_head_dim = qk_nope_head_dim
        self.v_head_dim = v_head_dim
        self.kv_lora_rank = kv_lora_rank
        self.model_dtype = model_dtype
        self.use_mla_kernel = use_mla_kernel

        self._transpose_map = {
            # dense mlp
            r"mlp\.down_proj": (1, 0),
            r"mlp\.gate_proj": (1, 0),
            r"mlp\.up_proj": (1, 0),
            # mla
            r"q_a_proj": (1, 0),
            r"q_b_proj": (1, 0),
            r"kv_a_proj_with_mqa": (1, 0),
            r"kv_b_proj": (1, 0),
            r"k_b_proj": (2, 0, 1),  # used for MLA kernel
            r"v_b_proj": (2, 0, 1),  # used for MLA kernel
            r"o_proj": (1, 0),
            # moe
            r"mlp\.gate\.weight": (1, 0),
            r"mlp\.experts\.\d+\.gate_proj": (0, 2, 1),
            r"mlp\.experts\.\d+\.down_proj": (0, 2, 1),
            r"mlp\.experts\.\d+\.up_proj": (0, 2, 1),
            r"mlp\.shared_experts\.down_proj": (1, 0),
            r"mlp\.shared_experts\.gate_proj": (1, 0),
            r"mlp\.shared_experts\.up_proj": (1, 0),
            # lm_head
            r"lm_head\.weight": (1, 0)
        }

        # Set the mappings from loaded parameter keys to standardized names.
        self._loaded_to_standardized_keys = {
            # encode & decode
            "model.embed_tokens.weight":
            "embedder.input_embedding_table_VD",
            "lm_head.weight":
            "lm_head.input_embedding_table_DV",
            # final norm
            "model.norm.weight":
            "final_norm.scale",
            # norm in transformer blocks
            "model.layers.*.input_layernorm.weight":
            "layers.*.pre_attention_norm.scale",
            "model.layers.*.post_attention_layernorm.weight":
            "layers.*.pre_mlp_norm.scale",
            # attention (MLA)
            "model.layers.*.self_attn.q_a_layernorm.weight":
            "layers.*.attn.q_rms_norm.scale",
            "model.layers.*.self_attn.kv_a_layernorm.weight":
            "layers.*.attn.kv_rms_norm.scale",
            "model.layers.*.self_attn.q_a_proj.weight":
            "layers.*.attn.kernel_q_down_proj_DA",
            "model.layers.*.self_attn.q_b_proj.weight":
            "layers.*.attn.kernel_q_up_proj_AP",
            "model.layers.*.self_attn.kv_a_proj_with_mqa.weight":
            "layers.*.attn.kernel_kv_down_proj_DA",
            "model.layers.*.self_attn.kv_b_proj.weight":
            "layers.*.attn.kernel_kv_up_proj_AL",
            "model.layers.*.self_attn.o_proj.weight":
            "layers.*.attn.kernel_o_proj_RD",
            # Dense ffw
            "model.layers.*.mlp.gate_proj.weight":
            "layers.*.custom_module.kernel_gating_DF",
            "model.layers.*.mlp.up_proj.weight":
            "layers.*.custom_module.kernel_up_proj_DF",
            "model.layers.*.mlp.down_proj.weight":
            "layers.*.custom_module.kernel_down_proj_FD",
            # MOE(routed experts)
            "model.layers.*.mlp.gate.weight":
            "layers.*.custom_module.router.kernel_DE",
            "model.layers.*.mlp.gate.e_score_correction_bias":
            "layers.*.custom_module.router.bias_E",
            "model.layers.*.mlp.experts.*.gate_proj.weight":
            "layers.*.custom_module.kernel_gating_EDF",
            "model.layers.*.mlp.experts.*.down_proj.weight":
            "layers.*.custom_module.kernel_down_proj_EFD",
            "model.layers.*.mlp.experts.*.up_proj.weight":
            "layers.*.custom_module.kernel_up_proj_EDF",
            # MOE(shared experts)
            "model.layers.*.mlp.shared_experts.down_proj.weight":
            "layers.*.shared_experts.kernel_down_proj_FD",
            "model.layers.*.mlp.shared_experts.gate_proj.weight":
            "layers.*.shared_experts.kernel_gating_DF",
            "model.layers.*.mlp.shared_experts.up_proj.weight":
            "layers.*.shared_experts.kernel_up_proj_DF",
        }
        if self.use_mla_kernel:
            self._loaded_to_standardized_keys.update({
                "model.layers.*.self_attn.k_b_proj.weight":
                "layers.*.attn.kernel_k_up_proj_ANH",
                "model.layers.*.self_attn.v_b_proj.weight":
                "layers.*.attn.kernel_v_up_proj_ANH",
            })
        # TODO (jacobplatin): we should not be hard-coding these
        self.scale_dtype, self.quant_dtype = jnp.bfloat16, jnp.float8_e4m3fn

        self.is_model_quantized = not vllm_config.additional_config.get(
            "skip_quantization", False)

        if self.is_model_quantized:
            # NOTE: this is only needed for pre-quantized models when doing random weight loading
            # because the scales that Qwix configures by default don't necessarily match the
            # scales in practice
            # TODO (jacobplatin): remove or clean this up
            self.scale_shape_map_for_random_weight_loading = {
                # MoE experts (3D)
                "custom_module.kernel_down_proj_EFD": (256, 8, 7168),
                "custom_module.kernel_gating_EDF": (256, 28, 2048),
                "custom_module.kernel_up_proj_EDF": (256, 28, 2048),
                # Shared experts (2D)
                "shared_experts.kernel_down_proj_FD": (8, 7168),
                "shared_experts.kernel_gating_DF": (28, 2048),
                "shared_experts.kernel_up_proj_DF": (28, 2048),
                # Dense FFW (2D)
                "custom_module.kernel_gating_DF": (28, 18432),
                "custom_module.kernel_up_proj_DF": (28, 18432),
                "custom_module.kernel_down_proj_FD": (72, 7168),
                # Attention (3D for MLA, 2D for the rest)
                "attn.kernel_q_down_proj_DA": (28, 1536),
                "attn.kernel_q_up_proj_AP": (6, 24576),
                "attn.kernel_kv_down_proj_DA": (28, 576),
                "attn.kernel_kv_up_proj_AL": (2, 32768),
                "attn.kernel_o_proj_RD": (64, 7168),
                "attn.kernel_k_up_proj_ANH": (2, 128, 128),  # MLA
                "attn.kernel_v_up_proj_ANH": (2, 128, 128),  # MLA
            }

        # TODO (jacobplatin): remove this check eventually!
        assert self.quant_dtype == jnp.float8_e4m3fn, f"Expected quant_dtype to be float8_e4m3fn for DeepSeek but got {self.quant_dtype}"

    def map_loaded_to_standardized_name(self, loaded_key: str) -> str:
        # Find the corresponding model key using the HF key
        if "layer" in loaded_key:
            # extract layer number and replace it with *
            layer_num = re.search(r"layers\.(\d+)", loaded_key).group(1)
            layer_key = re.sub(r"layers\.\d+", "layers.*", loaded_key)
            # extract expert number if exists and replace it with *
            if "experts" in loaded_key and "shared_experts" not in loaded_key:
                layer_key = re.sub(r"experts\.\d+", "experts.*", layer_key)
            # get standardized key and replace * with layer number.
            mapped_key = self._loaded_to_standardized_keys.get(
                layer_key, loaded_key)
            mapped_key = re.sub(r"layers\.\*", f"layers.{layer_num}",
                                mapped_key)
        else:
            mapped_key = self._loaded_to_standardized_keys.get(
                loaded_key, loaded_key)
        return mapped_key

    def _transpose_params(self, param_key: str, param_tensor: jax.Array):
        for key, value in self._transpose_map.items():
            if re.search(key, param_key):
                return jnp.transpose(param_tensor, value)
        return param_tensor  # Base case / no-op

    def _process_moe_weights(self, loaded_name, loaded_weight, weights_dict):
        layer_num = re.search(r"layers\.(\d+)", loaded_name).group(1)
        expert_num_str = re.search(r"experts\.(\d+)", loaded_name).group(1)
        expert_idx = int(expert_num_str)

        if layer_num not in weights_dict:
            weights_dict[layer_num] = ([None] * self.num_routed_experts, 0)

        expert_list, count = weights_dict[layer_num]

        expert_list[expert_idx] = loaded_weight
        count += 1
        weights_dict[layer_num] = (expert_list, count)

        if count == self.num_routed_experts:
            stacked_weights = torch.stack(expert_list, axis=0)
            del weights_dict[layer_num]
            return stacked_weights
        return None

    def _load_individual_weight(self,
                                name,
                                weight,
                                model_params,
                                model_mesh,
                                scale=None) -> Tuple[int, int]:
        """
        Loads a single weight into the model.

        NOTE: if using the base quantized model, it is assumed that the Qwix abstract
        pass has been run and that the model weights are thus QArrays, which we
        will then load the weights/scales into.

        Args:
            name: The name of the weight.
            weight: The weight to load.
            model_params: The model parameters.
            model_mesh: The model mesh.
            scale: The scale of the weight (if using the pre-quantized model).

        Returns:
            Tuple[int, int]: The size (in bytes) for the given layer overall and per shard.
            NOTE: if using the pre-quantized model (with Qwix), we'll include the scale size as well.
        """
        mapped_name = self.map_loaded_to_standardized_name(name)
        base_model_weight = get_param(model_params, mapped_name)
        model_weight = base_model_weight.array.qvalue if hasattr(
            base_model_weight, "array") else base_model_weight
        sharding = base_model_weight.array.qvalue.sharding if hasattr(
            base_model_weight, "array") else base_model_weight.sharding

        # Convert weights from torch into numpy
        if weight.dtype == torch.uint8 and scale is not None:
            # Assume packed FP4 format when uint8 weights with scale provided
            weight_jax_u8 = jnp.array(weight.cpu().numpy())
            weight_np = u8_unpack_e2m1(weight_jax_u8)
            scale = scale.to(torch.float32).numpy().astype(self.scale_dtype)
        else:
            cast_type = model_weight.value.dtype
            # Special-case: FP4 values stored as FP8 for compatibility.
            # If the model expects float4_e2m1fn but the checkpoint provides FP8,
            # convert by numeric value (float32) then cast to float4.
            if cast_type == jnp.float4_e2m1fn and weight.dtype == torch.float8_e4m3fn:
                weight_np = jnp.array(weight.float().numpy()).astype(cast_type)
            else:
                torch_view_type = DTYPE_VIEW_MAP.get(jnp.dtype(cast_type))

                if torch_view_type:
                    # Avoid unnecessary upcasting and mem copy by viewing the tensor's
                    # raw data as integers before converting to a JAX array.
                    weight_np = jnp.array(
                        weight.view(torch_view_type).numpy()).view(cast_type)
                else:
                    raise ValueError(
                        f"Unsupported dtype for tensor conversion: {cast_type}"
                    )

            if scale is not None:
                scale = scale.to(torch.float32).numpy().astype(
                    self.scale_dtype)
        weight_np = self._transpose_params(name, weight_np)
        if scale is not None:
            scale = self._transpose_params(name, scale)
            # Ensure scale is broadcastable to weight_np by repeating per-axis.
            weight_shape = weight_np.shape
            scale_shape = scale.shape
            if len(weight_shape) == len(scale_shape):
                new_scale = scale
                for wdim, sdim in zip(weight_shape, scale_shape):
                    if (wdim % sdim != 0):
                        raise ValueError(
                            f"Weight dim {wdim} is not divisible by scale dim {sdim} for weight {name} with shape {weight_shape} and scale {scale_shape}!"
                        )
                if scale_shape != new_scale.shape:
                    logger.warning(
                        f"Adjusted scale shape {scale_shape} to {new_scale.shape} to match weight {weight_shape}"
                    )
                scale = new_scale
            else:
                raise ValueError(
                    f"Scale rank {scale_shape} does not match weight rank {weight_shape}"
                )

        if model_weight.value.shape != weight_np.shape:
            raise ValueError(
                f"Loaded shape for {name}: {weight_np.shape} "
                f"does not match model shape for {mapped_name}: {model_weight.value.shape}!"
            )

        def get_slice(index):
            return weight_np[index]

        def get_slice_scale(index):
            # ruff: noqa: F821
            return scale[index]

        sharded_array = jax.make_array_from_callback(
            weight_np.shape, NamedSharding(model_mesh, P(*sharding)),
            get_slice)

        if scale is not None:
            maybe_sharded_scale = scale
            # Since, by default, we'll use the same sharding as the weights, we might
            # encounter an error where the smaller/different sharding dim isn't divisible
            # by the parallel size
            # NOTE: Qwix expert dangyi@ mentioned that we don't need to worry about accuracy
            # impacts when sharing scales
            try:
                maybe_sharded_scale = jax.make_array_from_callback(
                    scale.shape, NamedSharding(model_mesh, P(*sharding)),
                    get_slice_scale)
            except ValueError:
                logger.warning(
                    f"Could not create sharded scale for {name} with shape {scale.shape} and sharding {sharding}, skipping sharding..."
                )
            assert base_model_weight.array.scale.value.dtype == maybe_sharded_scale.dtype, f"Expected dtype for model weight scale with name {mapped_name} and dtype ({base_model_weight.array.scale.value.dtype}) to match that of the incoming weight scale ({maybe_sharded_scale.dtype})"
            assert base_model_weight.array.qvalue.value.dtype == sharded_array.dtype, f"Expected dtype for model weight with name {mapped_name} and dtype ({base_model_weight.array.qvalue.value.dtype}) to match that of the incoming weight ({sharded_array.dtype})"
            base_model_weight.array.scale.value = maybe_sharded_scale
            base_model_weight.array.qvalue.value = sharded_array
        else:
            assert model_weight.value.dtype == sharded_array.dtype, f"Expected dtype for model weight with name {mapped_name} and dtype ({model_weight.value.dtype}) to match that of the incoming weight ({sharded_array.dtype})"
            model_weight.value = sharded_array

        model_weight_size_bytes = model_weight.nbytes / 1e9
        model_weight_local_size_bytes = model_weight.addressable_shards[
            0].data.nbytes / 1e9

        if scale is not None:
            model_weight_size_bytes += base_model_weight.array.scale.nbytes / 1e9
            model_weight_local_size_bytes += base_model_weight.array.scale.addressable_shards[
                0].data.nbytes / 1e9

        if self.is_verbose:
            logger.info(f"Memory usage after loading in {name}: "
                        f"hbm={utils.hbm_usage_gb(jax.local_devices())}Gb")
            print_param_info(model_weight, name)
            if scale is not None:
                print_param_info(base_model_weight.array.scale,
                                 "scale for " + name)

        del weight, scale
        return model_weight_size_bytes, model_weight_local_size_bytes

    def load_weights(self, model_for_loading: nnx.Module):
        model_params = nnx.state(model_for_loading)
        logger.warning(
            f"loaded_to_standardized_keys: {self._loaded_to_standardized_keys}"
        )
        cumulative_global_memory = 0
        cumulative_local_memory = 0
        mlp_experts_gate_proj_weights = {}
        mlp_experts_gate_proj_scales = {}
        mlp_experts_up_proj_weights = {}
        mlp_experts_up_proj_scales = {}
        mlp_experts_down_proj_weights = {}
        mlp_experts_down_proj_scales = {}
        quantized_weights = {}
        quantized_scales = {}
        with jax.default_device(jax.devices("cpu")[0]):
            for loaded_name, loaded_weight in self.names_and_weights_generator:
                # Skip if the model has fewer layers than original.
                if re.search(r"layers\.(\d+)", loaded_name):
                    layer_num = re.search(r"layers\.(\d+)",
                                          loaded_name).group(1)
                    if int(layer_num) >= self.num_layers:
                        del loaded_weight
                        continue
                if 'layers.61' in loaded_name:
                    # skip loading MTP module.
                    del loaded_weight
                    continue
                if re.search(r"experts\.(\d+)", loaded_name):
                    expert_num = re.search(r"experts\.(\d+)",
                                           loaded_name).group(1)
                    if int(expert_num) >= self.num_routed_experts:
                        del loaded_weight
                        continue
                # NOTE: loaded_weight will be a Torch tensor, so we need to convert it to the
                # equivalent jnp dtype
                # TODO (jacobplatin): refactor this so that we instead change / update `model_weights_generator`
                # instead of checking "weight_scale_inv" and assuming quantization method is fp8
                scale = None
                # Mixed quantization: accept both fp8 and packed fp4 (uint8) tensors
                allowed_quant_dtypes = {
                    j2t_dtype(self.quant_dtype.dtype), torch.uint8
                }
                if loaded_weight.dtype in allowed_quant_dtypes:
                    if self.is_model_quantized:
                        scale_name = loaded_name.replace(
                            ".weight", ".weight_scale_inv")
                        if scale_name in quantized_scales:
                            scale = quantized_scales[scale_name]
                            del quantized_scales[scale_name]
                        else:
                            quantized_weights[loaded_name] = loaded_weight
                            continue
                    else:
                        quantized_weights[loaded_name] = loaded_weight
                        continue

                if loaded_name.endswith(".weight_scale_inv"):
                    if self.is_model_quantized:
                        weight_name = loaded_name.replace(
                            ".weight_scale_inv", ".weight")
                        if weight_name in quantized_weights:
                            scale = loaded_weight
                            loaded_weight = quantized_weights[weight_name]
                            loaded_name = weight_name
                            del quantized_weights[weight_name]
                        else:
                            quantized_scales[loaded_name] = loaded_weight
                            continue
                    # In the case that we don't want to use the quantized weights,
                    # we'll dequantize the weights using the loaded scale on-the-fly
                    else:
                        # assuming weights are loaded before scales.
                        weight_name = loaded_name.replace(
                            ".weight_scale_inv", ".weight")
                        loaded_weight = weights_dequant_cpu(
                            quantized_weights[weight_name], loaded_weight,
                            self.model_dtype)
                        loaded_name = weight_name
                        del quantized_weights[weight_name]
                # concat mlp.experts weights
                stacked_scales = None
                stacked_weights = None
                if "mlp.experts" in loaded_name:
                    if "down_proj" in loaded_name:
                        stacked_weights = self._process_moe_weights(
                            loaded_name, loaded_weight,
                            mlp_experts_down_proj_weights)
                        if scale is not None:
                            stacked_scales = self._process_moe_weights(
                                loaded_name, scale,
                                mlp_experts_down_proj_scales)
                    if "gate_proj" in loaded_name:
                        stacked_weights = self._process_moe_weights(
                            loaded_name, loaded_weight,
                            mlp_experts_gate_proj_weights)
                        if scale is not None:
                            stacked_scales = self._process_moe_weights(
                                loaded_name, scale,
                                mlp_experts_gate_proj_scales)
                    if "up_proj" in loaded_name:
                        stacked_weights = self._process_moe_weights(
                            loaded_name, loaded_weight,
                            mlp_experts_up_proj_weights)
                        if scale is not None:
                            stacked_scales = self._process_moe_weights(
                                loaded_name, scale, mlp_experts_up_proj_scales)
                    if stacked_weights is not None:
                        weight_bytes, weight_shards = self._load_individual_weight(
                            loaded_name,
                            stacked_weights,
                            model_params,
                            model_for_loading.mesh,
                            scale=stacked_scales)
                        if self.is_verbose:
                            cumulative_global_memory += weight_bytes
                            cumulative_local_memory += weight_shards
                            logger.info(
                                f"Cumulative global memory: {cumulative_global_memory} GB"
                            )
                            logger.info(
                                f"Cumulative local memory: {cumulative_local_memory} GB"
                            )
                else:
                    if self.use_mla_kernel and "kv_b_proj" in loaded_name:
                        # loaded_weight shape: (num_heads * (d_k + d_v), kv_lora_rank)
                        # scale shape: (num_heads * (d_k + d_v) / block_n, kv_lora_rank / block_k)
                        # Reshape to (num_heads, (d_k + d_v), kv_lora_rank) and split
                        weight_reshaped = loaded_weight.view(
                            self.attn_heads,
                            self.qk_nope_head_dim + self.v_head_dim,
                            self.kv_lora_rank)
                        k_weight = weight_reshaped[:, :self.qk_nope_head_dim, :]
                        v_weight = weight_reshaped[:, self.qk_nope_head_dim:, :]

                        loaded_weights_list = [k_weight, v_weight]
                        loaded_names = [
                            loaded_name.replace("kv_b_proj", "k_b_proj"),
                            loaded_name.replace("kv_b_proj", "v_b_proj")
                        ]

                        scales_list = [None, None]
                        if scale is not None:
                            assert loaded_weight.shape[0] == scale.shape[0]
                            block_size_k = loaded_weight.shape[
                                1] // scale.shape[1]
                            assert block_size_k > 0, f"Expected non-zero block size but got {block_size_k}!"
                            scale_reshaped = scale.view(
                                self.attn_heads,
                                (self.qk_nope_head_dim + self.v_head_dim),
                                self.kv_lora_rank // block_size_k)

                            k_scale = scale_reshaped[:, :self.qk_nope_head_dim, :]
                            v_scale = scale_reshaped[:, self.qk_nope_head_dim:, :]
                            scales_list = [k_scale, v_scale]

                    else:
                        loaded_weights_list = [loaded_weight]
                        loaded_names = [loaded_name]
                        scales_list = [scale]

                    for loaded_name, loaded_weight, scale in zip(
                            loaded_names, loaded_weights_list, scales_list):

                        weight_bytes, weight_shards = self._load_individual_weight(
                            loaded_name,
                            loaded_weight,
                            model_params,
                            model_for_loading.mesh,
                            scale=scale)
                        if self.is_verbose:
                            cumulative_global_memory += weight_bytes
                            cumulative_local_memory += weight_shards
                            logger.info(
                                f"Cumulative global memory: {cumulative_global_memory} GB"
                            )
                            logger.info(
                                f"Cumulative local memory: {cumulative_local_memory} GB"
                            )

        del mlp_experts_gate_proj_weights
        del mlp_experts_up_proj_weights
        del mlp_experts_down_proj_weights
        del quantized_weights
        del quantized_scales
        # TODO: validate that all of the model_params were accounted for as well.
        nnx.update(model_for_loading, model_params)


def weights_dequant_cpu(x: torch.Tensor,
                        s: torch.Tensor,
                        output_dtype: jnp.dtype,
                        block_size: int = 128) -> torch.Tensor:
    assert x.dim() == 2 and s.dim() == 2, "Both x and s must be 2D tensors"
    M, N = x.shape

    x = x.to(torch.float32)
    s = s.to(torch.float32)
    y = torch.empty_like(x)

    M_main = (M // block_size) * block_size
    N_main = (N // block_size) * block_size

    if M_main > 0 and N_main > 0:
        x_main = x[:M_main, :N_main]
        s_main = s[:(M // block_size), :(N // block_size)]

        x_reshaped = x_main.view(M // block_size, block_size, N // block_size,
                                 block_size).permute(0, 2, 1, 3)
        s_reshaped = s_main.view(M // block_size, N // block_size, 1, 1)
        y_main = (x_reshaped * s_reshaped).permute(0, 2, 1,
                                                   3).reshape(M_main, N_main)

        y[:M_main, :N_main] = y_main

    if N_main < N:
        for i in range(0, M_main, block_size):
            block = x[i:i + block_size, N_main:N]
            scale = s[i // block_size, N // block_size]
            y[i:i + block_size, N_main:N] = block * scale

    if M_main < M:
        for j in range(0, N, block_size):
            block = x[M_main:M, j:j + block_size]
            scale = s[M // block_size, j // block_size]
            y[M_main:M, j:j + block_size] = block * scale

    return y.to(j2t_dtype(jnp.dtype(output_dtype)))