tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of tpu-inference might be problematic.

Files changed (257)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +317 -34
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +406 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +320 -0
  64. tests/layers/vllm/test_unquantized.py +662 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +26 -6
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +25 -4
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +25 -12
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  154. tpu_inference/layers/common/quant_methods.py +15 -0
  155. tpu_inference/layers/common/quantization.py +282 -0
  156. tpu_inference/layers/common/sharding.py +32 -9
  157. tpu_inference/layers/common/utils.py +94 -0
  158. tpu_inference/layers/jax/__init__.py +13 -0
  159. tpu_inference/layers/jax/attention/__init__.py +13 -0
  160. tpu_inference/layers/jax/attention/attention.py +19 -6
  161. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  162. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  163. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  164. tpu_inference/layers/jax/base.py +14 -0
  165. tpu_inference/layers/jax/constants.py +13 -0
  166. tpu_inference/layers/jax/layers.py +14 -0
  167. tpu_inference/layers/jax/misc.py +14 -0
  168. tpu_inference/layers/jax/moe/__init__.py +13 -0
  169. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  170. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  171. tpu_inference/layers/jax/moe/moe.py +43 -3
  172. tpu_inference/layers/jax/pp_utils.py +53 -0
  173. tpu_inference/layers/jax/rope.py +14 -0
  174. tpu_inference/layers/jax/rope_interface.py +14 -0
  175. tpu_inference/layers/jax/sample/__init__.py +13 -0
  176. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  177. tpu_inference/layers/jax/sample/sampling.py +15 -1
  178. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  179. tpu_inference/layers/jax/transformer_block.py +14 -0
  180. tpu_inference/layers/vllm/__init__.py +13 -0
  181. tpu_inference/layers/vllm/attention.py +4 -4
  182. tpu_inference/layers/vllm/fused_moe.py +101 -494
  183. tpu_inference/layers/vllm/linear.py +64 -0
  184. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  185. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  186. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  187. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  188. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  189. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  191. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
  192. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
  193. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  194. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  195. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  196. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
  197. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  198. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
  199. tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
  200. tpu_inference/lora/__init__.py +13 -0
  201. tpu_inference/lora/torch_lora_ops.py +8 -13
  202. tpu_inference/models/__init__.py +13 -0
  203. tpu_inference/models/common/__init__.py +13 -0
  204. tpu_inference/models/common/model_loader.py +112 -35
  205. tpu_inference/models/jax/__init__.py +13 -0
  206. tpu_inference/models/jax/deepseek_v3.py +267 -157
  207. tpu_inference/models/jax/gpt_oss.py +26 -10
  208. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  209. tpu_inference/models/jax/llama3.py +99 -36
  210. tpu_inference/models/jax/llama4.py +14 -0
  211. tpu_inference/models/jax/llama_eagle3.py +18 -5
  212. tpu_inference/models/jax/llama_guard_4.py +15 -1
  213. tpu_inference/models/jax/qwen2.py +17 -2
  214. tpu_inference/models/jax/qwen2_5_vl.py +179 -51
  215. tpu_inference/models/jax/qwen3.py +17 -2
  216. tpu_inference/models/jax/utils/__init__.py +13 -0
  217. tpu_inference/models/jax/utils/file_utils.py +14 -0
  218. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  219. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  220. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
  221. tpu_inference/models/jax/utils/weight_utils.py +234 -155
  222. tpu_inference/models/vllm/__init__.py +13 -0
  223. tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
  224. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  225. tpu_inference/platforms/__init__.py +14 -0
  226. tpu_inference/platforms/tpu_platform.py +51 -72
  227. tpu_inference/runner/__init__.py +13 -0
  228. tpu_inference/runner/compilation_manager.py +180 -80
  229. tpu_inference/runner/kv_cache.py +54 -20
  230. tpu_inference/runner/kv_cache_manager.py +55 -33
  231. tpu_inference/runner/lora_utils.py +16 -1
  232. tpu_inference/runner/multimodal_manager.py +16 -2
  233. tpu_inference/runner/persistent_batch_manager.py +54 -2
  234. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  235. tpu_inference/runner/structured_decoding_manager.py +16 -3
  236. tpu_inference/runner/tpu_runner.py +124 -61
  237. tpu_inference/runner/utils.py +2 -2
  238. tpu_inference/spec_decode/__init__.py +13 -0
  239. tpu_inference/spec_decode/jax/__init__.py +13 -0
  240. tpu_inference/spec_decode/jax/eagle3.py +84 -22
  241. tpu_inference/tpu_info.py +14 -0
  242. tpu_inference/utils.py +72 -44
  243. tpu_inference/worker/__init__.py +13 -0
  244. tpu_inference/worker/tpu_worker.py +66 -52
  245. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
  246. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  247. tpu_inference/layers/vllm/linear_common.py +0 -186
  248. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  249. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  250. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  251. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  252. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  253. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  254. tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
  255. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  256. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  257. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tpu_inference/models/jax/deepseek_v3.py
@@ -1,3 +1,18 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
  import re
  from dataclasses import dataclass
  from typing import List, Optional, Tuple
@@ -13,6 +28,8 @@ from torchax.ops.mappings import j2t_dtype
  from vllm.config import VllmConfig

  from tpu_inference import utils
+ from tpu_inference.layers.common.quantization import u8_unpack_e2m1
+ from tpu_inference.layers.common.sharding import ShardingAxisName
  from tpu_inference.layers.jax.attention.attention import AttentionMetadata
  from tpu_inference.layers.jax.attention.deepseek_v3_attention import MLA
  from tpu_inference.layers.jax.constants import KVCacheType
@@ -23,10 +40,8 @@ from tpu_inference.layers.jax.moe.moe import MoE
  from tpu_inference.layers.jax.transformer_block import (
  SharedExpertsTransformerBlock, TransformerBlock)
  from tpu_inference.logger import init_logger
- from tpu_inference.models.jax.utils.quantization.quantization_utils import \
- get_quant_dtype_from_qwix_config
  from tpu_inference.models.jax.utils.weight_utils import (
- get_param, model_weights_generator, print_param_info, reshape_params)
+ get_param, model_weights_generator, print_param_info)

  logger = init_logger(__name__)

@@ -69,6 +84,9 @@ class DeepSeekV3(nnx.Module):
  hidden_act: str = "silu"
  rms_norm_eps: float = 1e-06
  first_k_dense_replace: int = 3 # replace the first few MOE layers to dense layer.
+ self.use_mla_kernel: bool = self.vllm_config.model_config.use_mla
+
+ logger.info(f"Is using MLA kernel in DeepSeek: {self.use_mla_kernel}")

  num_shared_experts = 1
  rope_theta = 10000
@@ -114,19 +132,30 @@ class DeepSeekV3(nnx.Module):
  qk_rope_head_dim=qk_rope_head_dim,
  v_head_dim=v_head_dim,
  num_local_experts=num_local_experts,
- model_dtype=dtype)
+ model_dtype=dtype,
+ use_mla_kernel=self.use_mla_kernel)

  self.embedder = Embedder(vocab_size=vocab_size,
  hidden_size=hidden_size,
  dtype=dtype,
  rngs=self.rng,
- vd_sharding=(('data', 'expert', 'model'),
+ vd_sharding=(ShardingAxisName.MLP_TENSOR,
  None),
  random_init=self.random_init)

  self.layers = []

  def _create_mla() -> MLA:
+ if self.use_mla_kernel:
+ query_tnh_spec = P(ShardingAxisName.MLP_TENSOR, None, None)
+ keyvalue_skh_spec = P(ShardingAxisName.MLP_TENSOR, None)
+ attn_o_tnh_spec = P(ShardingAxisName.MLP_TENSOR, None, None)
+
+ else:
+ query_tnh_spec = P(None, ShardingAxisName.MLP_TENSOR, None)
+ keyvalue_skh_spec = P(None, ShardingAxisName.MLP_TENSOR, None)
+ attn_o_tnh_spec = P(None, ShardingAxisName.MLP_TENSOR, None)
+
  return MLA(
  rope_theta=rope_theta,
  rope_scaling=rope_scaling,
@@ -137,10 +166,12 @@ class DeepSeekV3(nnx.Module):
  rms_norm_eps=rms_norm_eps,
  v_head_dim=v_head_dim,
  mesh=self.mesh,
+ use_mla_kernel=self.use_mla_kernel,
  random_init=self.random_init,
  hidden_size=hidden_size,
  num_attention_heads=num_attention_heads,
- num_key_value_heads=num_key_value_heads,
+ num_key_value_heads=1
+ if self.use_mla_kernel else num_key_value_heads,
  head_dim=v_head_dim, # MLA uses v_head_dim as head_dim
  dtype=dtype,
  # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
@@ -148,14 +179,15 @@ class DeepSeekV3(nnx.Module):
  rngs=self.rng,
  activation_attention_td=(None, None),
  activation_q_td=(None, None),
- query_tnh=P(None, 'model', None),
- keyvalue_skh=P(None, 'model', None),
+ query_tnh=query_tnh_spec,
+ keyvalue_skh=keyvalue_skh_spec,
  activation_attention_out_td=(None, None),
- attn_o_tnh=P(None, 'model', None),
- q_da_sharding=(None, 'model'),
- anh_sharding=(None, 'model', None),
- kv_da_sharding=(None, 'model'),
- nhd_sharding=('model', None, None))
+ attn_o_tnh=attn_o_tnh_spec,
+ q_da_sharding=(None, ShardingAxisName.VOCAB),
+ ap_sharding=(None, ShardingAxisName.MLP_TENSOR),
+ anh_sharding=(None, ShardingAxisName.MLP_TENSOR, None),
+ kv_da_sharding=(None, ShardingAxisName.VOCAB),
+ rd_sharding=(ShardingAxisName.MLP_TENSOR, None))

  for i in range(first_k_dense_replace):
  block = TransformerBlock(
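
The `_create_mla` branches above only differ in which axis of the attention activations is sharded: the MLA-kernel path shards the token axis, while the fallback shards the attention-head axis. A minimal JAX sketch of what those two PartitionSpec choices mean for a (tokens, heads, head_dim) array, assuming a one-axis mesh standing in for ShardingAxisName.MLP_TENSOR (the mesh axis name "model" and the array sizes are illustrative):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-dimensional device mesh; its single axis plays the role of MLP_TENSOR.
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

# Dims chosen to divide evenly on power-of-two meshes.
x = jnp.zeros((16, 16, 128))  # (tokens, num_heads, head_dim)

# MLA-kernel path: split the token axis across devices.
tokens_sharded = jax.device_put(x, NamedSharding(mesh, P("model", None, None)))

# Fallback path: split the attention-head axis across devices.
heads_sharded = jax.device_put(x, NamedSharding(mesh, P(None, "model", None)))

print(tokens_sharded.sharding, heads_sharded.sharding)
```
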
@@ -176,14 +208,15 @@ class DeepSeekV3(nnx.Module):
  rngs=self.rng,
  ),
  attn=_create_mla(),
- custom_module=DenseFFW(dtype=dtype,
- hidden_act=hidden_act,
- hidden_size=hidden_size,
- intermediate_size=ffw_intermediate_size,
- rngs=self.rng,
- df_sharding=(None, ('model', 'expert')),
- fd_sharding=(('model', 'expert'), None),
- random_init=self.random_init))
+ custom_module=DenseFFW(
+ dtype=dtype,
+ hidden_act=hidden_act,
+ hidden_size=hidden_size,
+ intermediate_size=ffw_intermediate_size,
+ rngs=self.rng,
+ df_sharding=(None, ShardingAxisName.MLP_TENSOR),
+ fd_sharding=(ShardingAxisName.MLP_TENSOR, None),
+ random_init=self.random_init))

  self.layers.append(block)

@@ -200,9 +233,9 @@ class DeepSeekV3(nnx.Module):
  rngs=self.rng,
  routed_scaling_factor=2.5,
  dtype=dtype,
- activation_ffw_td=('data', None),
- ed_sharding=('model', None),
- e_sharding=('model', ))
+ activation_ffw_td=(ShardingAxisName.MLP_DATA, None),
+ ed_sharding=(ShardingAxisName.MLP_TENSOR, None),
+ e_sharding=(ShardingAxisName.MLP_TENSOR, ))
  if self.sparse_matmul:
  # TODO: orginize the SparseMoE and DenseMoE better given they share most interfaces
  custom_module = SparseMoE(
@@ -216,10 +249,10 @@ class DeepSeekV3(nnx.Module):
  hidden_act=hidden_act,
  rngs=self.rng,
  random_init=self.random_init,
- activation_ffw_td=('data', None),
- activation_ffw_ted=('data', None, None),
- edf_sharding=('model', None, None),
- efd_sharding=('model', None, None),
+ activation_ffw_td=(ShardingAxisName.MLP_TENSOR, None),
+ activation_ffw_ted=(ShardingAxisName.MLP_DATA, None, None),
+ edf_sharding=(ShardingAxisName.MLP_TENSOR, None, None),
+ efd_sharding=(ShardingAxisName.MLP_TENSOR, None, None),
  quantized_dtype=self.weight_loader.quant_dtype
  if self.weight_loader.is_model_quantized else None,
  router=router) if is_moe_layer else DenseFFW(
@@ -229,8 +262,8 @@ class DeepSeekV3(nnx.Module):
  intermediate_size=ffw_intermediate_size,
  rngs=self.rng,
  random_init=self.random_init,
- df_sharding=(None, ('model', 'expert')),
- fd_sharding=(('model', 'expert'), None))
+ df_sharding=(None, ShardingAxisName.MLP_TENSOR),
+ fd_sharding=(ShardingAxisName.MLP_TENSOR, None))
  else:
  custom_module = MoE(
  dtype=dtype,
@@ -241,10 +274,10 @@ class DeepSeekV3(nnx.Module):
  hidden_act=hidden_act,
  rngs=self.rng,
  random_init=self.random_init,
- activation_ffw_td=('data', None),
- activation_ffw_ted=('data', None, None),
- edf_sharding=('model', None, None),
- efd_sharding=('model', None, None),
+ activation_ffw_td=(ShardingAxisName.MLP_DATA, None),
+ activation_ffw_ted=(ShardingAxisName.MLP_DATA, None, None),
+ edf_sharding=(ShardingAxisName.MLP_TENSOR, None, None),
+ efd_sharding=(ShardingAxisName.MLP_TENSOR, None, None),
  router=router) if is_moe_layer else DenseFFW(
  dtype=dtype,
  hidden_act=hidden_act,
@@ -252,18 +285,18 @@ class DeepSeekV3(nnx.Module):
  intermediate_size=ffw_intermediate_size,
  rngs=self.rng,
  random_init=self.random_init,
- df_sharding=(None, ('model', 'expert')),
- fd_sharding=(('model', 'expert'), None))
-
- shared_experts = DenseFFW(dtype=dtype,
- hidden_act=hidden_act,
- hidden_size=hidden_size,
- intermediate_size=num_shared_experts *
- moe_intermediate_size,
- rngs=self.rng,
- random_init=self.random_init,
- df_sharding=(None, ('model', 'expert')),
- fd_sharding=(('model', 'expert'), None))
+ df_sharding=(None, ShardingAxisName.MLP_TENSOR),
+ fd_sharding=(ShardingAxisName.MLP_TENSOR, None))
+
+ shared_experts = DenseFFW(
+ dtype=dtype,
+ hidden_act=hidden_act,
+ hidden_size=hidden_size,
+ intermediate_size=num_shared_experts * moe_intermediate_size,
+ rngs=self.rng,
+ random_init=self.random_init,
+ df_sharding=(None, ShardingAxisName.MLP_TENSOR),
+ fd_sharding=(ShardingAxisName.MLP_TENSOR, None))

  pre_attention_norm = RMSNorm(
  dims=hidden_size,
@@ -304,10 +337,28 @@ class DeepSeekV3(nnx.Module):
  hidden_size=hidden_size,
  dtype=dtype,
  rngs=self.rng,
- vd_sharding=(('data', 'expert', 'model'), None),
- dv_sharding=(None, ('data', 'expert', 'model')),
+ vd_sharding=(ShardingAxisName.MLP_TENSOR, None),
+ dv_sharding=(None, ShardingAxisName.MLP_TENSOR),
  random_init=self.random_init)

+ if os.environ.get("VLLM_LOGGING_LEVEL", "").upper() == "DEBUG":
+ self._print_model_architecture()
+
+ def _print_model_architecture(self):
+ num_display_layers = 5
+
+ logger.debug("### Embedding ###")
+ nnx.display(self.embedder)
+
+ logger.debug(f"\n### First {num_display_layers} Layers ###")
+ # Loop through the slice and display each layer
+ for i, layer in enumerate(self.layers[:num_display_layers]):
+ logger.debug(f"\n--- Layer {i} ---")
+ nnx.display(layer)
+
+ logger.debug("\n### LM Head ###")
+ nnx.display(self.lm_head)
+
  # For compatibility with flax.
  def apply(self, variables, *args, **kwargs):
  return self.__call__(*args, **kwargs)
@@ -352,10 +403,19 @@ class DeepSeekV3(nnx.Module):
  @dataclass
  class DeepSeekV3WeightLoader:

- def __init__(self, vllm_config: VllmConfig, num_layers, hidden_size,
- q_lora_rank, kv_lora_rank, attn_heads, qk_nope_head_dim,
- qk_rope_head_dim, v_head_dim, num_local_experts, model_dtype):
-
+ def __init__(self,
+ vllm_config: VllmConfig,
+ num_layers,
+ hidden_size,
+ q_lora_rank,
+ kv_lora_rank,
+ attn_heads,
+ qk_nope_head_dim,
+ qk_rope_head_dim,
+ v_head_dim,
+ num_local_experts,
+ model_dtype,
+ use_mla_kernel=False):
  self.num_layers = num_layers
  self.names_and_weights_generator = model_weights_generator(
  model_name_or_path=vllm_config.model_config.model,
@@ -364,7 +424,12 @@ class DeepSeekV3WeightLoader:
  self.is_verbose = vllm_config.additional_config.get(
  "is_verbose", None) is not None
  self.num_routed_experts = num_local_experts
+ self.attn_heads = attn_heads
+ self.qk_nope_head_dim = qk_nope_head_dim
+ self.v_head_dim = v_head_dim
+ self.kv_lora_rank = kv_lora_rank
  self.model_dtype = model_dtype
+ self.use_mla_kernel = use_mla_kernel

  self._transpose_map = {
  # dense mlp
@@ -373,10 +438,12 @@ class DeepSeekV3WeightLoader:
  r"mlp\.up_proj": (1, 0),
  # mla
  r"q_a_proj": (1, 0),
- r"q_b_proj": (2, 0, 1),
+ r"q_b_proj": (1, 0),
  r"kv_a_proj_with_mqa": (1, 0),
- r"kv_b_proj": (2, 0, 1),
- r"o_proj": (1, 2, 0),
+ r"kv_b_proj": (1, 0),
+ r"k_b_proj": (2, 0, 1), # used for MLA kernel
+ r"v_b_proj": (2, 0, 1), # used for MLA kernel
+ r"o_proj": (1, 0),
  # moe
  r"mlp\.gate\.weight": (1, 0),
  r"mlp\.experts\.\d+\.gate_proj": (0, 2, 1),
@@ -388,13 +455,6 @@ class DeepSeekV3WeightLoader:
  # lm_head
  r"lm_head\.weight": (1, 0)
  }
- self._weight_shape_map = {
- "q_b_proj":
- (attn_heads, qk_nope_head_dim + qk_rope_head_dim, q_lora_rank),
- "kv_b_proj":
- (attn_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank),
- "o_proj": (hidden_size, attn_heads, v_head_dim)
- }

  # Set the mappings from loaded parameter keys to standardized names.
  self._loaded_to_standardized_keys = {
@@ -419,13 +479,13 @@ class DeepSeekV3WeightLoader:
  "model.layers.*.self_attn.q_a_proj.weight":
  "layers.*.attn.kernel_q_down_proj_DA",
  "model.layers.*.self_attn.q_b_proj.weight":
- "layers.*.attn.kernel_q_up_proj_ANH",
+ "layers.*.attn.kernel_q_up_proj_AP",
  "model.layers.*.self_attn.kv_a_proj_with_mqa.weight":
  "layers.*.attn.kernel_kv_down_proj_DA",
  "model.layers.*.self_attn.kv_b_proj.weight":
- "layers.*.attn.kernel_kv_up_proj_ANH",
+ "layers.*.attn.kernel_kv_up_proj_AL",
  "model.layers.*.self_attn.o_proj.weight":
- "layers.*.attn.kernel_o_proj_NHD",
+ "layers.*.attn.kernel_o_proj_RD",
  # Dense ffw
  "model.layers.*.mlp.gate_proj.weight":
  "layers.*.custom_module.kernel_gating_DF",
@@ -452,57 +512,50 @@ class DeepSeekV3WeightLoader:
  "model.layers.*.mlp.shared_experts.up_proj.weight":
  "layers.*.shared_experts.kernel_up_proj_DF",
  }
-
- # TODO (jacobplatin): we shouldn't hard-code this, but the logic to obtain the true quantized dtype
- # is non-trivial and the default checkpoints all use this dtype
- self.quant_dtype = jnp.float8_e4m3fn
+ if self.use_mla_kernel:
+ self._loaded_to_standardized_keys.update({
+ "model.layers.*.self_attn.k_b_proj.weight":
+ "layers.*.attn.kernel_k_up_proj_ANH",
+ "model.layers.*.self_attn.v_b_proj.weight":
+ "layers.*.attn.kernel_v_up_proj_ANH",
+ })
+ # TODO (jacobplatin): we should not be hard-coding these
+ self.scale_dtype, self.quant_dtype = jnp.bfloat16, jnp.float8_e4m3fn

  self.is_model_quantized = not vllm_config.additional_config.get(
  "skip_quantization", False)
- if self.is_model_quantized:
- # TODO (jacobplatin): expand support eventually
- quantization_type = vllm_config.model_config.hf_config.quantization_config[
- "quant_method"]
- assert quantization_type == "fp8", "DeepSeek only supports the fp8 quantization method for now"
- self.scale_dtype, self.quant_dtype = get_quant_dtype_from_qwix_config(
- vllm_config)
-
- logger.info(
- f"Quantizing DeepSeek with quantization dtype: {self.quant_dtype} and scale dtype: {self.scale_dtype}"
- )

- quantization_block_sizes = vllm_config.model_config.hf_config.quantization_config[
- "weight_block_size"]
- assert len(
- quantization_block_sizes
- ) == 2, f"Expected only 2 quantization block sizes but got {quantization_block_sizes}"
- self.quantization_block_size_n = quantization_block_sizes[0]
- self.quantization_block_size_k = quantization_block_sizes[1]
- # TODO (jacobplatin): remove this check in the future
- assert self.quantization_block_size_n == self.quantization_block_size_k, "Quantization block size n and k must be the same!"
- # NOTE: this is only needed for pre-quantized models
- self._scale_shape_map = {
- "q_b_proj": (1, qk_nope_head_dim + qk_rope_head_dim,
- q_lora_rank // self.quantization_block_size_n),
- "kv_b_proj": (attn_heads, (qk_nope_head_dim + v_head_dim) //
- self.quantization_block_size_n,
- kv_lora_rank // self.quantization_block_size_n),
- "o_proj":
- (hidden_size // self.quantization_block_size_n, attn_heads,
- v_head_dim // self.quantization_block_size_n),
- }
+ if self.is_model_quantized:
  # NOTE: this is only needed for pre-quantized models when doing random weight loading
+ # because the scales that Qwix configures by default don't necessarily match the
+ # scales in practice
  # TODO (jacobplatin): remove or clean this up
- self.scale_shap_map_for_random_weight_loading = {
- "kernel_kv_down_proj_DA": (56, 576),
- "kernel_kv_up_proj_ANH": (4, 128, 2),
- "kernel_q_up_proj_ANH": (12, 1, 192),
- "kernel_o_proj_NHD": (128, 1, 56),
- "kernel_down_proj_EFD": (256, 16, 56),
- "kernel_up_proj_EDF": (256, 56, 16),
- "kernel_gating_EDF": (256, 56, 16),
+ self.scale_shape_map_for_random_weight_loading = {
+ # MoE experts (3D)
+ "custom_module.kernel_down_proj_EFD": (256, 8, 7168),
+ "custom_module.kernel_gating_EDF": (256, 28, 2048),
+ "custom_module.kernel_up_proj_EDF": (256, 28, 2048),
+ # Shared experts (2D)
+ "shared_experts.kernel_down_proj_FD": (8, 7168),
+ "shared_experts.kernel_gating_DF": (28, 2048),
+ "shared_experts.kernel_up_proj_DF": (28, 2048),
+ # Dense FFW (2D)
+ "custom_module.kernel_gating_DF": (28, 18432),
+ "custom_module.kernel_up_proj_DF": (28, 18432),
+ "custom_module.kernel_down_proj_FD": (72, 7168),
+ # Attention (3D for MLA, 2D for the rest)
+ "attn.kernel_q_down_proj_DA": (28, 1536),
+ "attn.kernel_q_up_proj_AP": (6, 24576),
+ "attn.kernel_kv_down_proj_DA": (28, 576),
+ "attn.kernel_kv_up_proj_AL": (2, 32768),
+ "attn.kernel_o_proj_RD": (64, 7168),
+ "attn.kernel_k_up_proj_ANH": (2, 128, 128), # MLA
+ "attn.kernel_v_up_proj_ANH": (2, 128, 128), # MLA
  }

+ # TODO (jacobplatin): remove this check eventually!
+ assert self.quant_dtype == jnp.float8_e4m3fn, f"Expected quant_dtype to be float8_e4m3fn for DeepSeek but got {self.quant_dtype}"
+
  def map_loaded_to_standardized_name(self, loaded_key: str) -> str:
  # Find the corresponding model key using the HF key
  if "layer" in loaded_key:
@@ -580,45 +633,56 @@ class DeepSeekV3WeightLoader:
  base_model_weight, "array") else base_model_weight.sharding

  # Convert weights from torch into numpy
- cast_type = model_weight.value.dtype
-
- torch_view_type = DTYPE_VIEW_MAP.get(jnp.dtype(cast_type))
-
- if torch_view_type:
- # Avoid unnecessary upcasting and mem copy by viewing the tensor's
- # raw data as integers before converting to a JAX array.
- weight_np = jnp.array(
- weight.view(torch_view_type).numpy()).view(cast_type)
+ if weight.dtype == torch.uint8 and scale is not None:
+ # Assume packed FP4 format when uint8 weights with scale provided
+ weight_jax_u8 = jnp.array(weight.cpu().numpy())
+ weight_np = u8_unpack_e2m1(weight_jax_u8)
+ scale = scale.to(torch.float32).numpy().astype(self.scale_dtype)
  else:
- raise ValueError(
- f"Unsupported dtype for tensor conversion: {cast_type}")
+ cast_type = model_weight.value.dtype
+ # Special-case: FP4 values stored as FP8 for compatibility.
+ # If the model expects float4_e2m1fn but the checkpoint provides FP8,
+ # convert by numeric value (float32) then cast to float4.
+ if cast_type == jnp.float4_e2m1fn and weight.dtype == torch.float8_e4m3fn:
+ weight_np = jnp.array(weight.float().numpy()).astype(cast_type)
+ else:
+ torch_view_type = DTYPE_VIEW_MAP.get(jnp.dtype(cast_type))

- if scale is not None:
- scale = scale.to(torch.float32).numpy().astype(self.scale_dtype)
+ if torch_view_type:
+ # Avoid unnecessary upcasting and mem copy by viewing the tensor's
+ # raw data as integers before converting to a JAX array.
+ weight_np = jnp.array(
+ weight.view(torch_view_type).numpy()).view(cast_type)
+ else:
+ raise ValueError(
+ f"Unsupported dtype for tensor conversion: {cast_type}"
+ )

- # Reshape and transpose weights if necessary.
- weight_np = reshape_params(name, weight_np, self._weight_shape_map)
- if scale is not None:
- scale = reshape_params(name, scale, self._scale_shape_map)
+ if scale is not None:
+ scale = scale.to(torch.float32).numpy().astype(
+ self.scale_dtype)
  weight_np = self._transpose_params(name, weight_np)
  if scale is not None:
  scale = self._transpose_params(name, scale)
+ # Ensure scale is broadcastable to weight_np by repeating per-axis.
  weight_shape = weight_np.shape
  scale_shape = scale.shape
- assert len(weight_shape) == len(scale_shape)
- for idx, (weight_dim,
- scale_dim) in enumerate(zip(weight_shape, scale_shape)):
- if weight_dim // self.quantization_block_size_n != scale_dim and weight_dim // scale_dim != 1:
- old_scale_shape = scale.shape
- scale = scale.repeat(self.quantization_block_size_n,
- axis=idx)[:, :weight_dim]
+ if len(weight_shape) == len(scale_shape):
+ new_scale = scale
+ for wdim, sdim in zip(weight_shape, scale_shape):
+ if (wdim % sdim != 0):
+ raise ValueError(
+ f"Weight dim {wdim} is not divisible by scale dim {sdim} for weight {name} with shape {weight_shape} and scale {scale_shape}!"
+ )
+ if scale_shape != new_scale.shape:
  logger.warning(
- f"Got a weight with shape {weight_shape} and scale with shape {old_scale_shape} "
- f"where the scale_dim {scale_dim} does not match the weight_dim {weight_dim} "
- f"multiplied by the quantization block size {self.quantization_block_size_n}. "
- f"Repeating the scale to new shape {scale.shape} along axis {idx} with repeat size {self.quantization_block_size_n}."
+ f"Adjusted scale shape {scale_shape} to {new_scale.shape} to match weight {weight_shape}"
  )
- break
+ scale = new_scale
+ else:
+ raise ValueError(
+ f"Scale rank {scale_shape} does not match weight rank {weight_shape}"
+ )

  if model_weight.value.shape != weight_np.shape:
  raise ValueError(
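
The reworked scale handling above now only validates rank equality and per-axis divisibility between the weight and its block-wise scale, rather than repeating the scale in place. A minimal NumPy sketch of the block-wise expansion such scales imply (shapes and block sizes are illustrative, not taken from the checkpoint):

```python
import numpy as np

def broadcast_blockwise_scale(weight: np.ndarray, scale: np.ndarray) -> np.ndarray:
    """Expand a block-wise scale so it can multiply `weight` elementwise.

    Assumes every weight dim is an integer multiple of the matching scale dim,
    which is exactly what the loader's divisibility check enforces.
    """
    assert weight.ndim == scale.ndim
    for axis, (wdim, sdim) in enumerate(zip(weight.shape, scale.shape)):
        assert wdim % sdim == 0, f"dim {axis}: {wdim} not divisible by {sdim}"
        scale = np.repeat(scale, wdim // sdim, axis=axis)
    return scale

# Example: a (256, 512) weight with one scale per (128, 128) block.
w = np.ones((256, 512), dtype=np.float32)
s = np.full((2, 4), 0.5, dtype=np.float32)
dequantized = w * broadcast_blockwise_scale(w, s)
```
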
@@ -652,10 +716,8 @@ class DeepSeekV3WeightLoader:
  logger.warning(
  f"Could not create sharded scale for {name} with shape {scale.shape} and sharding {sharding}, skipping sharding..."
  )
- # NOTE: Despite the fact that scale has the name `scale_inv` in it, we don't need to
- # inverse it
- assert base_model_weight.array.scale.value.dtype == maybe_sharded_scale.dtype, "Expected dtype for model weight scale with name {mapped_name} and dtype ({base_model_weight.array.scale.value.dtype}) to match that of the incoming weight scale ({maybe_sharded_scale.dtype})"
- assert base_model_weight.array.qvalue.value.dtype == sharded_array.dtype, "Expected dtype for model weight with name {mapped_name} and dtype ({base_model_weight.array.qvalue.value.dtype}) to match that of the incoming weight ({sharded_array.dtype})"
+ assert base_model_weight.array.scale.value.dtype == maybe_sharded_scale.dtype, f"Expected dtype for model weight scale with name {mapped_name} and dtype ({base_model_weight.array.scale.value.dtype}) to match that of the incoming weight scale ({maybe_sharded_scale.dtype})"
+ assert base_model_weight.array.qvalue.value.dtype == sharded_array.dtype, f"Expected dtype for model weight with name {mapped_name} and dtype ({base_model_weight.array.qvalue.value.dtype}) to match that of the incoming weight ({sharded_array.dtype})"
  base_model_weight.array.scale.value = maybe_sharded_scale
  base_model_weight.array.qvalue.value = sharded_array
  else:
@@ -721,7 +783,11 @@ class DeepSeekV3WeightLoader:
  # TODO (jacobplatin): refactor this so that we instead change / update `model_weights_generator`
  # instead of checking "weight_scale_inv" and assuming quantization method is fp8
  scale = None
- if loaded_weight.dtype == j2t_dtype(self.quant_dtype.dtype):
+ # Mixed quantization: accept both fp8 and packed fp4 (uint8) tensors
+ allowed_quant_dtypes = {
+ j2t_dtype(self.quant_dtype.dtype), torch.uint8
+ }
+ if loaded_weight.dtype in allowed_quant_dtypes:
  if self.is_model_quantized:
  scale_name = loaded_name.replace(
  ".weight", ".weight_scale_inv")
@@ -802,21 +868,65 @@ class DeepSeekV3WeightLoader:
  f"Cumulative local memory: {cumulative_local_memory} GB"
  )
  else:
- weight_bytes, weight_shards = self._load_individual_weight(
- loaded_name,
- loaded_weight,
- model_params,
- model_for_loading.mesh,
- scale=scale)
- if self.is_verbose:
- cumulative_global_memory += weight_bytes
- cumulative_local_memory += weight_shards
- logger.info(
- f"Cumulative global memory: {cumulative_global_memory} GB"
- )
- logger.info(
- f"Cumulative local memory: {cumulative_local_memory} GB"
- )
+ if self.use_mla_kernel and "kv_b_proj" in loaded_name:
+ # loaded_weight shape: (num_heads * (d_k + d_v), kv_lora_rank)
+ # scale shape: (num_heads * (d_k + d_v) / block_n, kv_lora_rank / block_k)
+ # Reshape to (num_heads, (d_k + d_v), kv_lora_rank) and split
+ weight_reshaped = loaded_weight.view(
+ self.attn_heads,
+ self.qk_nope_head_dim + self.v_head_dim,
+ self.kv_lora_rank)
+ k_weight = weight_reshaped[:, :self.
+ qk_nope_head_dim, :]
+ v_weight = weight_reshaped[:,
+ self.qk_nope_head_dim:, :]
+
+ loaded_weights_list = [k_weight, v_weight]
+ loaded_names = [
+ loaded_name.replace("kv_b_proj", "k_b_proj"),
+ loaded_name.replace("kv_b_proj", "v_b_proj")
+ ]
+
+ scales_list = [None, None]
+ if scale is not None:
+ assert loaded_weight.shape[0] == scale.shape[0]
+ block_size_k = loaded_weight.shape[
+ 1] // scale.shape[1]
+ assert block_size_k > 0, f"Expected non-zero block size but got {block_size_k}!"
+ scale_reshaped = scale.view(
+ self.attn_heads,
+ (self.qk_nope_head_dim + self.v_head_dim),
+ self.kv_lora_rank // block_size_k)
+
+ k_scale = scale_reshaped[:, :self.
+ qk_nope_head_dim, :]
+ v_scale = scale_reshaped[:,
+ self.qk_nope_head_dim:, :]
+ scales_list = [k_scale, v_scale]
+
+ else:
+ loaded_weights_list = [loaded_weight]
+ loaded_names = [loaded_name]
+ scales_list = [scale]
+
+ for loaded_name, loaded_weight, scale in zip(
+ loaded_names, loaded_weights_list, scales_list):
+
+ weight_bytes, weight_shards = self._load_individual_weight(
+ loaded_name,
+ loaded_weight,
+ model_params,
+ model_for_loading.mesh,
+ scale=scale)
+ if self.is_verbose:
+ cumulative_global_memory += weight_bytes
+ cumulative_local_memory += weight_shards
+ logger.info(
+ f"Cumulative global memory: {cumulative_global_memory} GB"
+ )
+ logger.info(
+ f"Cumulative local memory: {cumulative_local_memory} GB"
+ )

  del mlp_experts_gate_proj_weights
  del mlp_experts_up_proj_weights
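
The final hunk splits the fused kv_b_proj checkpoint tensor into separate k_b_proj and v_b_proj tensors before handing each to _load_individual_weight. The same reshape-and-slice shown standalone, using illustrative DeepSeek-V3-like dimensions (the concrete sizes here are assumptions for the example, not read from the config):

```python
import torch

num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 128, 128, 128, 512

# Fused checkpoint tensor: (num_heads * (d_k + d_v), kv_lora_rank)
kv_b_proj = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

# Reshape to (num_heads, d_k + d_v, kv_lora_rank), then slice along dim 1.
kv = kv_b_proj.view(num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank)
k_b_proj = kv[:, :qk_nope_head_dim, :]  # (num_heads, d_k, kv_lora_rank)
v_b_proj = kv[:, qk_nope_head_dim:, :]  # (num_heads, d_v, kv_lora_rank)

assert k_b_proj.shape == (num_heads, qk_nope_head_dim, kv_lora_rank)
assert v_b_proj.shape == (num_heads, v_head_dim, kv_lora_rank)
```

Any block-wise weight scale is split the same way after being reshaped with the derived block_size_k, so each half keeps the scale rows that correspond to its slice of the fused matrix.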