tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.2rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +78 -1
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +38 -7
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +17 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +28 -5
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +74 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +89 -26
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -64
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +72 -37
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +46 -17
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +44 -17
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +42 -36
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +63 -50
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/METADATA +7 -9
- tpu_inference-0.13.2rc3.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/top_level.txt +0 -0
The detailed hunks below are from tpu_inference/models/jax/deepseek_v3.py (+267 -157).

@@ -1,3 +1,18 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
 import re
 from dataclasses import dataclass
 from typing import List, Optional, Tuple
@@ -13,6 +28,8 @@ from torchax.ops.mappings import j2t_dtype
 from vllm.config import VllmConfig

 from tpu_inference import utils
+from tpu_inference.layers.common.quantization import u8_unpack_e2m1
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.jax.attention.attention import AttentionMetadata
 from tpu_inference.layers.jax.attention.deepseek_v3_attention import MLA
 from tpu_inference.layers.jax.constants import KVCacheType
@@ -23,10 +40,8 @@ from tpu_inference.layers.jax.moe.moe import MoE
 from tpu_inference.layers.jax.transformer_block import (
     SharedExpertsTransformerBlock, TransformerBlock)
 from tpu_inference.logger import init_logger
-from tpu_inference.models.jax.utils.quantization.quantization_utils import \
-    get_quant_dtype_from_qwix_config
 from tpu_inference.models.jax.utils.weight_utils import (
-    get_param, model_weights_generator, print_param_info
+    get_param, model_weights_generator, print_param_info)

 logger = init_logger(__name__)

@@ -69,6 +84,9 @@ class DeepSeekV3(nnx.Module):
         hidden_act: str = "silu"
         rms_norm_eps: float = 1e-06
         first_k_dense_replace: int = 3  # replace the first few MOE layers to dense layer.
+        self.use_mla_kernel: bool = self.vllm_config.model_config.use_mla
+
+        logger.info(f"Is using MLA kernel in DeepSeek: {self.use_mla_kernel}")

         num_shared_experts = 1
         rope_theta = 10000
@@ -114,19 +132,30 @@ class DeepSeekV3(nnx.Module):
             qk_rope_head_dim=qk_rope_head_dim,
             v_head_dim=v_head_dim,
             num_local_experts=num_local_experts,
-            model_dtype=dtype
+            model_dtype=dtype,
+            use_mla_kernel=self.use_mla_kernel)

         self.embedder = Embedder(vocab_size=vocab_size,
                                  hidden_size=hidden_size,
                                  dtype=dtype,
                                  rngs=self.rng,
-                                 vd_sharding=(
+                                 vd_sharding=(ShardingAxisName.MLP_TENSOR,
                                               None),
                                  random_init=self.random_init)

         self.layers = []

         def _create_mla() -> MLA:
+            if self.use_mla_kernel:
+                query_tnh_spec = P(ShardingAxisName.MLP_TENSOR, None, None)
+                keyvalue_skh_spec = P(ShardingAxisName.MLP_TENSOR, None)
+                attn_o_tnh_spec = P(ShardingAxisName.MLP_TENSOR, None, None)
+
+            else:
+                query_tnh_spec = P(None, ShardingAxisName.MLP_TENSOR, None)
+                keyvalue_skh_spec = P(None, ShardingAxisName.MLP_TENSOR, None)
+                attn_o_tnh_spec = P(None, ShardingAxisName.MLP_TENSOR, None)
+
             return MLA(
                 rope_theta=rope_theta,
                 rope_scaling=rope_scaling,
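The `_create_mla` branches above differ only in which logical axis of the (token, head, head_dim) activations is partitioned: the MLA-kernel path shards the token axis, the fallback path shards the head axis. Below is a minimal sketch of what such `PartitionSpec`s do on a JAX mesh; the mesh and its axis name are placeholders, not the axes `ShardingAxisName` actually defines.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 1-D device mesh; tpu-inference builds its own mesh elsewhere.
mesh = Mesh(np.array(jax.devices()).reshape(-1), axis_names=("model",))

x = np.zeros((16, 8, 128), dtype=np.float32)  # (tokens, heads, head_dim)

# MLA-kernel layout: partition the token axis across devices.
token_sharded = jax.device_put(x, NamedSharding(mesh, P("model", None, None)))
# Fallback layout: partition the head axis instead.
head_sharded = jax.device_put(x, NamedSharding(mesh, P(None, "model", None)))

print(token_sharded.sharding, head_sharded.sharding)
```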
@@ -137,10 +166,12 @@ class DeepSeekV3(nnx.Module):
                 rms_norm_eps=rms_norm_eps,
                 v_head_dim=v_head_dim,
                 mesh=self.mesh,
+                use_mla_kernel=self.use_mla_kernel,
                 random_init=self.random_init,
                 hidden_size=hidden_size,
                 num_attention_heads=num_attention_heads,
-                num_key_value_heads=
+                num_key_value_heads=1
+                if self.use_mla_kernel else num_key_value_heads,
                 head_dim=v_head_dim,  # MLA uses v_head_dim as head_dim
                 dtype=dtype,
                 # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
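A side effect of routing through the MLA kernel is that the compressed KV latent is handled as a single logical head, which is why `num_key_value_heads` is forced to 1 above; without the kernel the model keeps the regular `num_key_value_heads` from the config.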
@@ -148,14 +179,15 @@ class DeepSeekV3(nnx.Module):
                 rngs=self.rng,
                 activation_attention_td=(None, None),
                 activation_q_td=(None, None),
-                query_tnh=
-                keyvalue_skh=
+                query_tnh=query_tnh_spec,
+                keyvalue_skh=keyvalue_skh_spec,
                 activation_attention_out_td=(None, None),
-                attn_o_tnh=
-                q_da_sharding=(None,
-
-
-
+                attn_o_tnh=attn_o_tnh_spec,
+                q_da_sharding=(None, ShardingAxisName.VOCAB),
+                ap_sharding=(None, ShardingAxisName.MLP_TENSOR),
+                anh_sharding=(None, ShardingAxisName.MLP_TENSOR, None),
+                kv_da_sharding=(None, ShardingAxisName.VOCAB),
+                rd_sharding=(ShardingAxisName.MLP_TENSOR, None))

         for i in range(first_k_dense_replace):
             block = TransformerBlock(
@@ -176,14 +208,15 @@ class DeepSeekV3(nnx.Module):
                     rngs=self.rng,
                 ),
                 attn=_create_mla(),
-                custom_module=DenseFFW(
-
-
-
-
-
-
-
+                custom_module=DenseFFW(
+                    dtype=dtype,
+                    hidden_act=hidden_act,
+                    hidden_size=hidden_size,
+                    intermediate_size=ffw_intermediate_size,
+                    rngs=self.rng,
+                    df_sharding=(None, ShardingAxisName.MLP_TENSOR),
+                    fd_sharding=(ShardingAxisName.MLP_TENSOR, None),
+                    random_init=self.random_init))

             self.layers.append(block)

@@ -200,9 +233,9 @@ class DeepSeekV3(nnx.Module):
                 rngs=self.rng,
                 routed_scaling_factor=2.5,
                 dtype=dtype,
-                activation_ffw_td=(
-                ed_sharding=(
-                e_sharding=(
+                activation_ffw_td=(ShardingAxisName.MLP_DATA, None),
+                ed_sharding=(ShardingAxisName.MLP_TENSOR, None),
+                e_sharding=(ShardingAxisName.MLP_TENSOR, ))
             if self.sparse_matmul:
                 # TODO: orginize the SparseMoE and DenseMoE better given they share most interfaces
                 custom_module = SparseMoE(
@@ -216,10 +249,10 @@ class DeepSeekV3(nnx.Module):
                     hidden_act=hidden_act,
                     rngs=self.rng,
                     random_init=self.random_init,
-                    activation_ffw_td=(
-                    activation_ffw_ted=(
-                    edf_sharding=(
-                    efd_sharding=(
+                    activation_ffw_td=(ShardingAxisName.MLP_TENSOR, None),
+                    activation_ffw_ted=(ShardingAxisName.MLP_DATA, None, None),
+                    edf_sharding=(ShardingAxisName.MLP_TENSOR, None, None),
+                    efd_sharding=(ShardingAxisName.MLP_TENSOR, None, None),
                     quantized_dtype=self.weight_loader.quant_dtype
                     if self.weight_loader.is_model_quantized else None,
                     router=router) if is_moe_layer else DenseFFW(
@@ -229,8 +262,8 @@ class DeepSeekV3(nnx.Module):
                     intermediate_size=ffw_intermediate_size,
                     rngs=self.rng,
                     random_init=self.random_init,
-                    df_sharding=(None,
-                    fd_sharding=(
+                    df_sharding=(None, ShardingAxisName.MLP_TENSOR),
+                    fd_sharding=(ShardingAxisName.MLP_TENSOR, None))
             else:
                 custom_module = MoE(
                     dtype=dtype,
@@ -241,10 +274,10 @@ class DeepSeekV3(nnx.Module):
                     hidden_act=hidden_act,
                     rngs=self.rng,
                     random_init=self.random_init,
-                    activation_ffw_td=(
-                    activation_ffw_ted=(
-                    edf_sharding=(
-                    efd_sharding=(
+                    activation_ffw_td=(ShardingAxisName.MLP_DATA, None),
+                    activation_ffw_ted=(ShardingAxisName.MLP_DATA, None, None),
+                    edf_sharding=(ShardingAxisName.MLP_TENSOR, None, None),
+                    efd_sharding=(ShardingAxisName.MLP_TENSOR, None, None),
                     router=router) if is_moe_layer else DenseFFW(
                     dtype=dtype,
                     hidden_act=hidden_act,
@@ -252,18 +285,18 @@ class DeepSeekV3(nnx.Module):
                     intermediate_size=ffw_intermediate_size,
                     rngs=self.rng,
                     random_init=self.random_init,
-                    df_sharding=(None,
-                    fd_sharding=(
-
-            shared_experts = DenseFFW(
-
-
-
-
-
-
-
+                    df_sharding=(None, ShardingAxisName.MLP_TENSOR),
+                    fd_sharding=(ShardingAxisName.MLP_TENSOR, None))
+
+            shared_experts = DenseFFW(
+                dtype=dtype,
+                hidden_act=hidden_act,
+                hidden_size=hidden_size,
+                intermediate_size=num_shared_experts * moe_intermediate_size,
+                rngs=self.rng,
+                random_init=self.random_init,
+                df_sharding=(None, ShardingAxisName.MLP_TENSOR),
+                fd_sharding=(ShardingAxisName.MLP_TENSOR, None))

             pre_attention_norm = RMSNorm(
                 dims=hidden_size,
@@ -304,10 +337,28 @@ class DeepSeekV3(nnx.Module):
             hidden_size=hidden_size,
             dtype=dtype,
             rngs=self.rng,
-            vd_sharding=(
-            dv_sharding=(None,
+            vd_sharding=(ShardingAxisName.MLP_TENSOR, None),
+            dv_sharding=(None, ShardingAxisName.MLP_TENSOR),
             random_init=self.random_init)

+        if os.environ.get("VLLM_LOGGING_LEVEL", "").upper() == "DEBUG":
+            self._print_model_architecture()
+
+    def _print_model_architecture(self):
+        num_display_layers = 5
+
+        logger.debug("### Embedding ###")
+        nnx.display(self.embedder)
+
+        logger.debug(f"\n### First {num_display_layers} Layers ###")
+        # Loop through the slice and display each layer
+        for i, layer in enumerate(self.layers[:num_display_layers]):
+            logger.debug(f"\n--- Layer {i} ---")
+            nnx.display(layer)
+
+        logger.debug("\n### LM Head ###")
+        nnx.display(self.lm_head)
+
     # For compatibility with flax.
     def apply(self, variables, *args, **kwargs):
         return self.__call__(*args, **kwargs)
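The `_print_model_architecture` dump introduced here is gated on vLLM's logging level: it only runs when the process is launched with `VLLM_LOGGING_LEVEL=DEBUG` (matching the `os.environ` check above), so it adds no overhead at the default logging level.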
@@ -352,10 +403,19 @@ class DeepSeekV3(nnx.Module):
 @dataclass
 class DeepSeekV3WeightLoader:

-    def __init__(self,
-
-
-
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 num_layers,
+                 hidden_size,
+                 q_lora_rank,
+                 kv_lora_rank,
+                 attn_heads,
+                 qk_nope_head_dim,
+                 qk_rope_head_dim,
+                 v_head_dim,
+                 num_local_experts,
+                 model_dtype,
+                 use_mla_kernel=False):
         self.num_layers = num_layers
         self.names_and_weights_generator = model_weights_generator(
             model_name_or_path=vllm_config.model_config.model,
@@ -364,7 +424,12 @@ class DeepSeekV3WeightLoader:
         self.is_verbose = vllm_config.additional_config.get(
             "is_verbose", None) is not None
         self.num_routed_experts = num_local_experts
+        self.attn_heads = attn_heads
+        self.qk_nope_head_dim = qk_nope_head_dim
+        self.v_head_dim = v_head_dim
+        self.kv_lora_rank = kv_lora_rank
         self.model_dtype = model_dtype
+        self.use_mla_kernel = use_mla_kernel

         self._transpose_map = {
             # dense mlp
@@ -373,10 +438,12 @@ class DeepSeekV3WeightLoader:
             r"mlp\.up_proj": (1, 0),
             # mla
             r"q_a_proj": (1, 0),
-            r"q_b_proj": (
+            r"q_b_proj": (1, 0),
             r"kv_a_proj_with_mqa": (1, 0),
-            r"kv_b_proj": (
-            r"
+            r"kv_b_proj": (1, 0),
+            r"k_b_proj": (2, 0, 1),  # used for MLA kernel
+            r"v_b_proj": (2, 0, 1),  # used for MLA kernel
+            r"o_proj": (1, 0),
             # moe
             r"mlp\.gate\.weight": (1, 0),
             r"mlp\.experts\.\d+\.gate_proj": (0, 2, 1),
@@ -388,13 +455,6 @@ class DeepSeekV3WeightLoader:
             # lm_head
             r"lm_head\.weight": (1, 0)
         }
-        self._weight_shape_map = {
-            "q_b_proj":
-            (attn_heads, qk_nope_head_dim + qk_rope_head_dim, q_lora_rank),
-            "kv_b_proj":
-            (attn_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank),
-            "o_proj": (hidden_size, attn_heads, v_head_dim)
-        }

         # Set the mappings from loaded parameter keys to standardized names.
         self._loaded_to_standardized_keys = {
@@ -419,13 +479,13 @@ class DeepSeekV3WeightLoader:
             "model.layers.*.self_attn.q_a_proj.weight":
             "layers.*.attn.kernel_q_down_proj_DA",
             "model.layers.*.self_attn.q_b_proj.weight":
-            "layers.*.attn.
+            "layers.*.attn.kernel_q_up_proj_AP",
             "model.layers.*.self_attn.kv_a_proj_with_mqa.weight":
             "layers.*.attn.kernel_kv_down_proj_DA",
             "model.layers.*.self_attn.kv_b_proj.weight":
-            "layers.*.attn.
+            "layers.*.attn.kernel_kv_up_proj_AL",
             "model.layers.*.self_attn.o_proj.weight":
-            "layers.*.attn.
+            "layers.*.attn.kernel_o_proj_RD",
             # Dense ffw
             "model.layers.*.mlp.gate_proj.weight":
             "layers.*.custom_module.kernel_gating_DF",
@@ -452,57 +512,50 @@ class DeepSeekV3WeightLoader:
             "model.layers.*.mlp.shared_experts.up_proj.weight":
             "layers.*.shared_experts.kernel_up_proj_DF",
         }
-
-
-
-
+        if self.use_mla_kernel:
+            self._loaded_to_standardized_keys.update({
+                "model.layers.*.self_attn.k_b_proj.weight":
+                "layers.*.attn.kernel_k_up_proj_ANH",
+                "model.layers.*.self_attn.v_b_proj.weight":
+                "layers.*.attn.kernel_v_up_proj_ANH",
+            })
+        # TODO (jacobplatin): we should not be hard-coding these
+        self.scale_dtype, self.quant_dtype = jnp.bfloat16, jnp.float8_e4m3fn

         self.is_model_quantized = not vllm_config.additional_config.get(
             "skip_quantization", False)
-        if self.is_model_quantized:
-            # TODO (jacobplatin): expand support eventually
-            quantization_type = vllm_config.model_config.hf_config.quantization_config[
-                "quant_method"]
-            assert quantization_type == "fp8", "DeepSeek only supports the fp8 quantization method for now"
-            self.scale_dtype, self.quant_dtype = get_quant_dtype_from_qwix_config(
-                vllm_config)
-
-            logger.info(
-                f"Quantizing DeepSeek with quantization dtype: {self.quant_dtype} and scale dtype: {self.scale_dtype}"
-            )

-
-                "weight_block_size"]
-            assert len(
-                quantization_block_sizes
-            ) == 2, f"Expected only 2 quantization block sizes but got {quantization_block_sizes}"
-            self.quantization_block_size_n = quantization_block_sizes[0]
-            self.quantization_block_size_k = quantization_block_sizes[1]
-            # TODO (jacobplatin): remove this check in the future
-            assert self.quantization_block_size_n == self.quantization_block_size_k, "Quantization block size n and k must be the same!"
-            # NOTE: this is only needed for pre-quantized models
-            self._scale_shape_map = {
-                "q_b_proj": (1, qk_nope_head_dim + qk_rope_head_dim,
-                             q_lora_rank // self.quantization_block_size_n),
-                "kv_b_proj": (attn_heads, (qk_nope_head_dim + v_head_dim) //
-                              self.quantization_block_size_n,
-                              kv_lora_rank // self.quantization_block_size_n),
-                "o_proj":
-                (hidden_size // self.quantization_block_size_n, attn_heads,
-                 v_head_dim // self.quantization_block_size_n),
-            }
+        if self.is_model_quantized:
             # NOTE: this is only needed for pre-quantized models when doing random weight loading
+            # because the scales that Qwix configures by default don't necessarily match the
+            # scales in practice
             # TODO (jacobplatin): remove or clean this up
-            self.
-
-            "
-            "
-            "
-
-            "
-            "
+            self.scale_shape_map_for_random_weight_loading = {
+                # MoE experts (3D)
+                "custom_module.kernel_down_proj_EFD": (256, 8, 7168),
+                "custom_module.kernel_gating_EDF": (256, 28, 2048),
+                "custom_module.kernel_up_proj_EDF": (256, 28, 2048),
+                # Shared experts (2D)
+                "shared_experts.kernel_down_proj_FD": (8, 7168),
+                "shared_experts.kernel_gating_DF": (28, 2048),
+                "shared_experts.kernel_up_proj_DF": (28, 2048),
+                # Dense FFW (2D)
+                "custom_module.kernel_gating_DF": (28, 18432),
+                "custom_module.kernel_up_proj_DF": (28, 18432),
+                "custom_module.kernel_down_proj_FD": (72, 7168),
+                # Attention (3D for MLA, 2D for the rest)
+                "attn.kernel_q_down_proj_DA": (28, 1536),
+                "attn.kernel_q_up_proj_AP": (6, 24576),
+                "attn.kernel_kv_down_proj_DA": (28, 576),
+                "attn.kernel_kv_up_proj_AL": (2, 32768),
+                "attn.kernel_o_proj_RD": (64, 7168),
+                "attn.kernel_k_up_proj_ANH": (2, 128, 128),  # MLA
+                "attn.kernel_v_up_proj_ANH": (2, 128, 128),  # MLA
             }

+        # TODO (jacobplatin): remove this check eventually!
+        assert self.quant_dtype == jnp.float8_e4m3fn, f"Expected quant_dtype to be float8_e4m3fn for DeepSeek but got {self.quant_dtype}"
+
    def map_loaded_to_standardized_name(self, loaded_key: str) -> str:
        # Find the corresponding model key using the HF key
        if "layer" in loaded_key:
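The hard-coded shapes in `scale_shape_map_for_random_weight_loading` appear to be DeepSeek-V3's weight dimensions with a 256-wide quantization block applied along a single axis while the other axes stay full-size: 7168/256 = 28, 18432/256 = 72, 2048/256 = 8, 1536/256 = 6, 512/256 = 2, and 128·128/256 = 64 for the output projection.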
@@ -580,45 +633,56 @@ class DeepSeekV3WeightLoader:
            base_model_weight, "array") else base_model_weight.sharding

        # Convert weights from torch into numpy
-
-
-
-
-
-        # Avoid unnecessary upcasting and mem copy by viewing the tensor's
-        # raw data as integers before converting to a JAX array.
-        weight_np = jnp.array(
-            weight.view(torch_view_type).numpy()).view(cast_type)
+        if weight.dtype == torch.uint8 and scale is not None:
+            # Assume packed FP4 format when uint8 weights with scale provided
+            weight_jax_u8 = jnp.array(weight.cpu().numpy())
+            weight_np = u8_unpack_e2m1(weight_jax_u8)
+            scale = scale.to(torch.float32).numpy().astype(self.scale_dtype)
        else:
-
-
+            cast_type = model_weight.value.dtype
+            # Special-case: FP4 values stored as FP8 for compatibility.
+            # If the model expects float4_e2m1fn but the checkpoint provides FP8,
+            # convert by numeric value (float32) then cast to float4.
+            if cast_type == jnp.float4_e2m1fn and weight.dtype == torch.float8_e4m3fn:
+                weight_np = jnp.array(weight.float().numpy()).astype(cast_type)
+            else:
+                torch_view_type = DTYPE_VIEW_MAP.get(jnp.dtype(cast_type))

-
-
+                if torch_view_type:
+                    # Avoid unnecessary upcasting and mem copy by viewing the tensor's
+                    # raw data as integers before converting to a JAX array.
+                    weight_np = jnp.array(
+                        weight.view(torch_view_type).numpy()).view(cast_type)
+                else:
+                    raise ValueError(
+                        f"Unsupported dtype for tensor conversion: {cast_type}"
+                    )

-
-
-
-            scale = reshape_params(name, scale, self._scale_shape_map)
+            if scale is not None:
+                scale = scale.to(torch.float32).numpy().astype(
+                    self.scale_dtype)
        weight_np = self._transpose_params(name, weight_np)
        if scale is not None:
            scale = self._transpose_params(name, scale)
+            # Ensure scale is broadcastable to weight_np by repeating per-axis.
            weight_shape = weight_np.shape
            scale_shape = scale.shape
-
-
-
-
-
-
-
+            if len(weight_shape) == len(scale_shape):
+                new_scale = scale
+                for wdim, sdim in zip(weight_shape, scale_shape):
+                    if (wdim % sdim != 0):
+                        raise ValueError(
+                            f"Weight dim {wdim} is not divisible by scale dim {sdim} for weight {name} with shape {weight_shape} and scale {scale_shape}!"
+                        )
+                if scale_shape != new_scale.shape:
                    logger.warning(
-                        f"
-                        f"where the scale_dim {scale_dim} does not match the weight_dim {weight_dim} "
-                        f"multiplied by the quantization block size {self.quantization_block_size_n}. "
-                        f"Repeating the scale to new shape {scale.shape} along axis {idx} with repeat size {self.quantization_block_size_n}."
+                        f"Adjusted scale shape {scale_shape} to {new_scale.shape} to match weight {weight_shape}"
                    )
-
+                scale = new_scale
+            else:
+                raise ValueError(
+                    f"Scale rank {scale_shape} does not match weight rank {weight_shape}"
+                )

        if model_weight.value.shape != weight_np.shape:
            raise ValueError(
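The new uint8 branch above handles quantized checkpoints whose tensors arrive as packed FP4 (two E2M1 values per byte) and defers the actual decoding to `u8_unpack_e2m1`. That helper's packing convention and output dtype are not shown in this diff; the sketch below only illustrates how a lookup-table E2M1 unpack typically works, assuming the low nibble comes first.

```python
import jax.numpy as jnp

# FP4 E2M1 magnitudes for codes 0-7; codes 8-15 are the negated values.
_E2M1_VALUES = jnp.array(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
    dtype=jnp.float32)


def unpack_e2m1_sketch(packed_u8: jnp.ndarray) -> jnp.ndarray:
    """Illustrative unpack of two E2M1 values per uint8 (low nibble first)."""
    low = packed_u8 & 0x0F
    high = (packed_u8 >> 4) & 0x0F
    # Interleave the two nibbles so the last axis doubles in length.
    codes = jnp.stack([low, high], axis=-1).reshape(*packed_u8.shape[:-1], -1)
    return _E2M1_VALUES[codes]
```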
@@ -652,10 +716,8 @@ class DeepSeekV3WeightLoader:
                logger.warning(
                    f"Could not create sharded scale for {name} with shape {scale.shape} and sharding {sharding}, skipping sharding..."
                )
-
-
-            assert base_model_weight.array.scale.value.dtype == maybe_sharded_scale.dtype, "Expected dtype for model weight scale with name {mapped_name} and dtype ({base_model_weight.array.scale.value.dtype}) to match that of the incoming weight scale ({maybe_sharded_scale.dtype})"
-            assert base_model_weight.array.qvalue.value.dtype == sharded_array.dtype, "Expected dtype for model weight with name {mapped_name} and dtype ({base_model_weight.array.qvalue.value.dtype}) to match that of the incoming weight ({sharded_array.dtype})"
+            assert base_model_weight.array.scale.value.dtype == maybe_sharded_scale.dtype, f"Expected dtype for model weight scale with name {mapped_name} and dtype ({base_model_weight.array.scale.value.dtype}) to match that of the incoming weight scale ({maybe_sharded_scale.dtype})"
+            assert base_model_weight.array.qvalue.value.dtype == sharded_array.dtype, f"Expected dtype for model weight with name {mapped_name} and dtype ({base_model_weight.array.qvalue.value.dtype}) to match that of the incoming weight ({sharded_array.dtype})"
            base_model_weight.array.scale.value = maybe_sharded_scale
            base_model_weight.array.qvalue.value = sharded_array
        else:
@@ -721,7 +783,11 @@ class DeepSeekV3WeightLoader:
                # TODO (jacobplatin): refactor this so that we instead change / update `model_weights_generator`
                # instead of checking "weight_scale_inv" and assuming quantization method is fp8
                scale = None
-
+                # Mixed quantization: accept both fp8 and packed fp4 (uint8) tensors
+                allowed_quant_dtypes = {
+                    j2t_dtype(self.quant_dtype.dtype), torch.uint8
+                }
+                if loaded_weight.dtype in allowed_quant_dtypes:
                    if self.is_model_quantized:
                        scale_name = loaded_name.replace(
                            ".weight", ".weight_scale_inv")
@@ -802,21 +868,65 @@ class DeepSeekV3WeightLoader:
                        f"Cumulative local memory: {cumulative_local_memory} GB"
                    )
                else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                    if self.use_mla_kernel and "kv_b_proj" in loaded_name:
+                        # loaded_weight shape: (num_heads * (d_k + d_v), kv_lora_rank)
+                        # scale shape: (num_heads * (d_k + d_v) / block_n, kv_lora_rank / block_k)
+                        # Reshape to (num_heads, (d_k + d_v), kv_lora_rank) and split
+                        weight_reshaped = loaded_weight.view(
+                            self.attn_heads,
+                            self.qk_nope_head_dim + self.v_head_dim,
+                            self.kv_lora_rank)
+                        k_weight = weight_reshaped[:, :self.
+                                                   qk_nope_head_dim, :]
+                        v_weight = weight_reshaped[:,
+                                                   self.qk_nope_head_dim:, :]
+
+                        loaded_weights_list = [k_weight, v_weight]
+                        loaded_names = [
+                            loaded_name.replace("kv_b_proj", "k_b_proj"),
+                            loaded_name.replace("kv_b_proj", "v_b_proj")
+                        ]
+
+                        scales_list = [None, None]
+                        if scale is not None:
+                            assert loaded_weight.shape[0] == scale.shape[0]
+                            block_size_k = loaded_weight.shape[
+                                1] // scale.shape[1]
+                            assert block_size_k > 0, f"Expected non-zero block size but got {block_size_k}!"
+                            scale_reshaped = scale.view(
+                                self.attn_heads,
+                                (self.qk_nope_head_dim + self.v_head_dim),
+                                self.kv_lora_rank // block_size_k)
+
+                            k_scale = scale_reshaped[:, :self.
+                                                     qk_nope_head_dim, :]
+                            v_scale = scale_reshaped[:,
+                                                     self.qk_nope_head_dim:, :]
+                            scales_list = [k_scale, v_scale]
+
+                    else:
+                        loaded_weights_list = [loaded_weight]
+                        loaded_names = [loaded_name]
+                        scales_list = [scale]
+
+                    for loaded_name, loaded_weight, scale in zip(
+                            loaded_names, loaded_weights_list, scales_list):
+
+                        weight_bytes, weight_shards = self._load_individual_weight(
+                            loaded_name,
+                            loaded_weight,
+                            model_params,
+                            model_for_loading.mesh,
+                            scale=scale)
+                        if self.is_verbose:
+                            cumulative_global_memory += weight_bytes
+                            cumulative_local_memory += weight_shards
+                            logger.info(
+                                f"Cumulative global memory: {cumulative_global_memory} GB"
+                            )
+                            logger.info(
+                                f"Cumulative local memory: {cumulative_local_memory} GB"
+                            )

                del mlp_experts_gate_proj_weights
                del mlp_experts_up_proj_weights
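The final hunk is the loader-side counterpart of the new `k_b_proj` / `v_b_proj` parameters: when the MLA kernel is active, the fused `kv_b_proj` checkpoint tensor (and its block scale, if present) is split per head into separate K and V up-projections before loading. A standalone sketch with assumed DeepSeek-V3-style dimensions (128 heads, qk_nope_head_dim and v_head_dim of 128, kv_lora_rank of 512; the real values come from the HF config):

```python
import torch

num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 128, 128, 128, 512

# Fused checkpoint tensor: (num_heads * (d_k + d_v), kv_lora_rank).
kv_b_proj = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

# Reshape to (num_heads, d_k + d_v, kv_lora_rank) and split along dim 1.
w = kv_b_proj.view(num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank)
k_b_proj = w[:, :qk_nope_head_dim, :]  # -> kernel_k_up_proj_ANH
v_b_proj = w[:, qk_nope_head_dim:, :]  # -> kernel_v_up_proj_ANH

print(k_b_proj.shape, v_b_proj.shape)  # torch.Size([128, 128, 512]) twice
```

The `(2, 0, 1)` entries added to `_transpose_map` then reorder each split tensor from (heads, head_dim, kv_lora_rank) to (kv_lora_rank, heads, head_dim), matching the `*_ANH` parameter layout.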